mirror of https://github.com/llvm/torch-mlir
Merge remote-tracking branch 'upstream/main' into emit_detach_
commit
99e2f143f0
|
@ -11,7 +11,7 @@ repos:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 22.10.0
|
rev: 24.4.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
---
|
||||||
|
|
||||||
|
extends: default
|
||||||
|
|
||||||
|
rules:
|
||||||
|
# These do not appear to be conventional in GitHub actions.
|
||||||
|
document-end:
|
||||||
|
present: false
|
||||||
|
document-start:
|
||||||
|
present: false
|
||||||
|
# GitHub actions use "on" for triggers.
|
||||||
|
truthy: disable
|
||||||
|
# We have lots of long strings and command lines.
|
||||||
|
line-length: disable
|
||||||
|
comments:
|
||||||
|
# Formatters may do this (e.g. Prettier does) and it seems like the most
|
||||||
|
# trivial thing to get a failing check for.
|
||||||
|
min-spaces-from-content: 1
|
||||||
|
# This is not a useful check, especially when disabling entire blocks.
|
||||||
|
comments-indentation: disable
|
||||||
|
|
||||||
|
ignore: /third_party/*
|
|
@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
|
||||||
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
|
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
|
||||||
-DLLVM_TARGETS_TO_BUILD=host \
|
-DLLVM_TARGETS_TO_BUILD=host \
|
||||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||||
|
-DTORCH_MLIR_ENABLE_LTC=ON
|
||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
echo "::group::Build"
|
echo "::group::Build"
|
||||||
|
|
|
@ -432,6 +432,8 @@ function clean_build() {
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_torch_mlir() {
|
function build_torch_mlir() {
|
||||||
|
# Disable LTC build for releases
|
||||||
|
export TORCH_MLIR_ENABLE_LTC=0
|
||||||
local torch_version="$1"
|
local torch_version="$1"
|
||||||
case $torch_version in
|
case $torch_version in
|
||||||
nightly)
|
nightly)
|
||||||
|
@ -440,7 +442,7 @@ function build_torch_mlir() {
|
||||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
--extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
CMAKE_GENERATOR=Ninja \
|
CMAKE_GENERATOR=Ninja \
|
||||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||||
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir \
|
python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \
|
||||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \
|
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \
|
||||||
-r /main_checkout/torch-mlir/whl-requirements.txt
|
-r /main_checkout/torch-mlir/whl-requirements.txt
|
||||||
;;
|
;;
|
||||||
|
@ -450,7 +452,7 @@ function build_torch_mlir() {
|
||||||
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
||||||
CMAKE_GENERATOR=Ninja \
|
CMAKE_GENERATOR=Ninja \
|
||||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||||
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
|
python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unrecognized torch version '$torch_version'"
|
echo "Unrecognized torch version '$torch_version'"
|
||||||
|
@ -474,7 +476,7 @@ function build_torch_mlir_core() {
|
||||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||||
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
|
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
|
||||||
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
|
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
|
||||||
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
|
python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
|
||||||
}
|
}
|
||||||
|
|
||||||
function clean_wheels() {
|
function clean_wheels() {
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
See https://github.com/llvm/torch-mlir/issues/1374
|
See https://github.com/llvm/torch-mlir/issues/1374
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit a952c123880eb1168f1021b116485e27170d48ca
|
Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf
|
|
@ -1 +1 @@
|
||||||
Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97
|
Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91
|
|
@ -30,8 +30,6 @@ namespace detail {
|
||||||
LogicalResult verifyTMTensorOpInterface(Operation *op);
|
LogicalResult verifyTMTensorOpInterface(Operation *op);
|
||||||
}
|
}
|
||||||
|
|
||||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
|
|
||||||
|
|
||||||
/// Include the generated interface declarations.
|
/// Include the generated interface declarations.
|
||||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
|
||||||
|
|
||||||
|
@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op);
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
|
||||||
|
|
||||||
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
||||||
|
|
|
@ -97,6 +97,31 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
|
||||||
|
for (auto result : op->getResults()) {
|
||||||
|
auto t = toValidTensorType(result.getType());
|
||||||
|
if (!t)
|
||||||
|
return failure();
|
||||||
|
typeList.push_back(t);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
||||||
|
// op.
|
||||||
|
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
||||||
|
if (idx >= op->getNumRegions())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
region = &op->getRegion(idx);
|
||||||
|
|
||||||
|
if (region == nullptr) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
|
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
|
||||||
int64_t idx) {
|
int64_t idx) {
|
||||||
if (idx >= op->getNumResults())
|
if (idx >= op->getNumResults())
|
||||||
|
|
|
@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder,
|
||||||
|
|
||||||
Type getQTorchTypeFromTorchIntType(Type ty);
|
Type getQTorchTypeFromTorchIntType(Type ty);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
|
||||||
|
Value &ofItem) {
|
||||||
|
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
|
||||||
|
rewriter.getType<T>(), ofItem);
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult OnnxLstmExpander(OpBinder binder,
|
LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
ConversionPatternRewriter &rewriter);
|
ConversionPatternRewriter &rewriter);
|
||||||
|
|
||||||
|
|
|
@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
|
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
|
||||||
size_t dimSizeIndexBits);
|
size_t dimSizeIndexBits);
|
||||||
|
|
||||||
|
// Get a tensor that collapse the specified dimensions of the input tensor
|
||||||
|
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
|
Value tensor, int64_t collapseStartDim,
|
||||||
|
int64_t collapseEndDim,
|
||||||
|
size_t dimSizeIndexBits);
|
||||||
|
|
||||||
|
// Get a tensor that splits the specified dimensions of the input tensor
|
||||||
|
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
|
Value tensor, int64_t splitDim,
|
||||||
|
int64_t outerLength, size_t dimSizeIndexBits);
|
||||||
|
|
||||||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||||
const APFloat &constant, Value shape,
|
const APFloat &constant, Value shape,
|
||||||
TensorType outType);
|
TensorType outType);
|
||||||
|
|
|
@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$alpha
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenCeluOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
AnyTorchScalarType:$alpha
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalNonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenCelu_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
|
@ -6590,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchListOfTorchIntType:$kernel_size,
|
||||||
|
AnyTorchListOfTorchIntType:$stride,
|
||||||
|
AnyTorchListOfTorchIntType:$padding,
|
||||||
|
AnyTorchListOfTorchIntType:$dilation,
|
||||||
|
Torch_BoolType:$ceil_mode
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -6645,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
|
||||||
printDefaultTorchOp(printer, *this, 6, 2);
|
printDefaultTorchOp(printer, *this, 6, 2);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
|
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
|
||||||
|
@ -13601,6 +13677,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchListOfTorchStringType:$l,
|
||||||
|
Torch_StringType:$item
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
|
def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -15904,6 +16005,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_PrimsVarOp : Torch_Op<"prims.var", [
|
def Torch_PrimsVarOp : Torch_Op<"prims.var", [
|
||||||
|
|
|
@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
|
||||||
return detail::torch_list_of_constant_bools_op_binder(bind_values);
|
return detail::torch_list_of_constant_bools_op_binder(bind_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
/// Matches the constant strs stored in a `torch.ListConstruct`.
|
||||||
|
struct torch_list_of_constant_strs_op_binder {
|
||||||
|
SmallVectorImpl<std::string> &bind_values;
|
||||||
|
|
||||||
|
/// Creates a matcher instance that binds the value to bvs if match succeeds.
|
||||||
|
torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
|
||||||
|
: bind_values(bvs) {}
|
||||||
|
|
||||||
|
bool match(Operation *op) {
|
||||||
|
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
|
||||||
|
if (!listConstruct)
|
||||||
|
return false;
|
||||||
|
for (Value value : listConstruct.getElements()) {
|
||||||
|
std::string str;
|
||||||
|
if (matchPattern(value, m_TorchConstantStr(str)))
|
||||||
|
bind_values.push_back(str);
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
|
||||||
|
inline detail::torch_list_of_constant_strs_op_binder
|
||||||
|
m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
|
||||||
|
return detail::torch_list_of_constant_strs_op_binder(bind_values);
|
||||||
|
}
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
/// Matches the expected tensor and dim from `torch.aten.size.int`.
|
/// Matches the expected tensor and dim from `torch.aten.size.int`.
|
||||||
struct torch_tensor_size_int_op_binder {
|
struct torch_tensor_size_int_op_binder {
|
||||||
|
|
|
@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
LogicalResult windowFunctionImpl(OpBinder binder,
|
||||||
|
ConversionPatternRewriter &rewriter,
|
||||||
|
Value size, Value a0, Value a1, Value a2,
|
||||||
|
Torch::ValueTensorType resultType,
|
||||||
|
int64_t output_datatype, int64_t periodic) {
|
||||||
|
|
||||||
|
Location loc = binder.getLoc();
|
||||||
|
ImplicitLocOpBuilder b(loc, rewriter);
|
||||||
|
|
||||||
|
double isPeriodicFp = static_cast<double>(periodic);
|
||||||
|
|
||||||
|
Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
|
||||||
|
Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
|
||||||
|
Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
|
||||||
|
|
||||||
|
constexpr double pi = llvm::numbers::pi;
|
||||||
|
Value tau = b.create<Torch::ConstantFloatOp>(
|
||||||
|
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
|
||||||
|
|
||||||
|
Value noneVal = b.create<Torch::ConstantNoneOp>();
|
||||||
|
Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
|
||||||
|
Value float32Type = b.create<Torch::ConstantIntOp>(
|
||||||
|
rewriter.getI64IntegerAttr(/*float32Type*/ 6));
|
||||||
|
|
||||||
|
// Create an f32 ValueTensorType with thse same size as size, the
|
||||||
|
// operand
|
||||||
|
auto shapeOfOperand =
|
||||||
|
size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
|
||||||
|
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
shapeOfOperand, rewriter.getF32Type());
|
||||||
|
Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
|
||||||
|
f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
|
||||||
|
Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
|
||||||
|
periodicSizeFloat.getType(), periodicSizeFloat, one, one);
|
||||||
|
|
||||||
|
Value isPeriodic =
|
||||||
|
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
|
||||||
|
Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
|
||||||
|
rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
|
||||||
|
|
||||||
|
Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
|
||||||
|
periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
|
||||||
|
Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
|
||||||
|
symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
|
||||||
|
Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
|
||||||
|
symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
|
||||||
|
|
||||||
|
// Here, size can be used in the place of periodicSizeFloat, as the
|
||||||
|
// latter is just a float representation of the former.
|
||||||
|
Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
|
||||||
|
|
||||||
|
Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
|
||||||
|
resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
|
||||||
|
|
||||||
|
Value rangeTimesTau =
|
||||||
|
b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
|
||||||
|
Value rangeAngular =
|
||||||
|
b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
|
||||||
|
Value twoRangeAngular =
|
||||||
|
b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
|
||||||
|
|
||||||
|
Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
|
||||||
|
Value cosTwoRangeAngular =
|
||||||
|
b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
|
||||||
|
|
||||||
|
Value a1Component =
|
||||||
|
b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
|
||||||
|
Value a2Component =
|
||||||
|
b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
|
||||||
|
|
||||||
|
// AtenSubScalarOp actually requires a tensor operand as the LHS, that
|
||||||
|
// is, operand #1. Therefore, to avoid errors, the onnx implementation
|
||||||
|
// has been modified. a1 has been changed to negative half, and the
|
||||||
|
// AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
|
||||||
|
// operation is commutative.
|
||||||
|
Value subA1Component =
|
||||||
|
b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
|
||||||
|
Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
|
||||||
|
a2Component, one);
|
||||||
|
|
||||||
|
std::optional<int64_t> dtypeIntTorch =
|
||||||
|
onnxDtypeIntToTorchDtypeInt(output_datatype);
|
||||||
|
if (!dtypeIntTorch.has_value()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "unimplemented support for the given dtype conversion");
|
||||||
|
}
|
||||||
|
Value outputDtype = b.create<Torch::ConstantIntOp>(
|
||||||
|
rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
|
dtypeIntTorch.value()));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
|
||||||
|
binder.op, resultType, result, outputDtype,
|
||||||
|
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||||
|
/*memory_format=*/noneVal);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Simple rewrites for the default domain.
|
// Simple rewrites for the default domain.
|
||||||
// See: https://onnx.ai/onnx/operators/
|
// See: https://onnx.ai/onnx/operators/
|
||||||
// For operators that are effectively version invariant, we register with
|
// For operators that are effectively version invariant, we register with
|
||||||
|
@ -2186,4 +2288,65 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
|
binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"BlackmanWindow", 17,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Value size;
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
int64_t periodic, output_datatype;
|
||||||
|
if (binder.tensorOperand(size) ||
|
||||||
|
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
|
||||||
|
binder.s64IntegerAttr(periodic, "periodic", 1) ||
|
||||||
|
binder.tensorResultType(resultType)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
|
||||||
|
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
|
||||||
|
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
|
||||||
|
|
||||||
|
auto windowFunctionResult =
|
||||||
|
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
|
||||||
|
output_datatype, periodic);
|
||||||
|
|
||||||
|
if (failed(windowFunctionResult))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
|
||||||
|
patterns.onOp(
|
||||||
|
"HannWindow", 17,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Value size;
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
int64_t periodic, output_datatype;
|
||||||
|
if (binder.tensorOperand(size) ||
|
||||||
|
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
|
||||||
|
binder.s64IntegerAttr(periodic, "periodic", 1) ||
|
||||||
|
binder.tensorResultType(resultType)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
|
||||||
|
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
|
||||||
|
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
|
||||||
|
|
||||||
|
auto windowFunctionResult =
|
||||||
|
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
|
||||||
|
output_datatype, periodic);
|
||||||
|
|
||||||
|
if (failed(windowFunctionResult))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
alignCorners);
|
alignCorners);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Value conditionTensor;
|
||||||
|
if (binder.tensorOperand(conditionTensor)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"condition bind failure");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto conditionType =
|
||||||
|
conditionTensor.getType().cast<Torch::ValueTensorType>();
|
||||||
|
if (!conditionType || conditionType.getSizes().size() != 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "condition must have one single element per "
|
||||||
|
"https://onnx.ai/onnx/operators/onnx__If.html");
|
||||||
|
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
conditionTensor);
|
||||||
|
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Type> resultTypes;
|
||||||
|
if (binder.tensorResultTypes(resultTypes)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"result type bind failure");
|
||||||
|
}
|
||||||
|
|
||||||
|
Region *thenRegion, *elseRegion;
|
||||||
|
if (binder.getRegionAtIndex(elseRegion, 0) ||
|
||||||
|
binder.getRegionAtIndex(thenRegion, 1)) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primIfOp = rewriter.create<Torch::PrimIfOp>(
|
||||||
|
binder.getLoc(), TypeRange(resultTypes), conditionBool);
|
||||||
|
|
||||||
|
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
|
||||||
|
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
|
||||||
|
};
|
||||||
|
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
|
||||||
|
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
|
||||||
|
|
||||||
|
auto replaceTerminator = [&](Region ®ion) {
|
||||||
|
PatternRewriter::InsertionGuard guard(rewriter);
|
||||||
|
Operation *terminator = region.front().getTerminator();
|
||||||
|
rewriter.setInsertionPoint(terminator);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
|
||||||
|
terminator, terminator->getOperands());
|
||||||
|
};
|
||||||
|
replaceTerminator(primIfOp.getThenRegion());
|
||||||
|
replaceTerminator(primIfOp.getElseRegion());
|
||||||
|
|
||||||
|
rewriter.replaceOp(binder.op, primIfOp.getResults());
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp("Less", 13,
|
patterns.onOp("Less", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c;
|
||||||
// thing here, so we simplify.
|
// thing here, so we simplify.
|
||||||
|
|
||||||
// utilities
|
// utilities
|
||||||
// Templatized function to get an item op of a type
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename T>
|
|
||||||
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
|
|
||||||
Value &ofItem) {
|
|
||||||
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
|
|
||||||
rewriter.getType<T>(), ofItem);
|
|
||||||
}
|
|
||||||
|
|
||||||
// In case the ReduceSum Op was not the first operation performed on the data,
|
// In case the ReduceSum Op was not the first operation performed on the data,
|
||||||
// we provide the original operand through storeResult, which will be modified
|
// we provide the original operand through storeResult, which will be modified
|
||||||
// if the result will be passed onto another operation, and will be used for
|
// if the result will be passed onto another operation, and will be used for
|
||||||
|
@ -847,12 +839,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
"Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
// y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
float alpha, gamma;
|
float alpha, gamma;
|
||||||
Value operand;
|
Value operand;
|
||||||
|
// Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default
|
||||||
|
// alpha and gamma values.
|
||||||
if (binder.tensorOperand(operand) ||
|
if (binder.tensorOperand(operand) ||
|
||||||
binder.f32FloatAttr(alpha, "alpha") ||
|
binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
|
||||||
binder.f32FloatAttr(gamma, "gamma") ||
|
binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -945,22 +940,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
/*memory_format=*/noneVal);
|
/*memory_format=*/noneVal);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp("ReduceSum", 1,
|
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
||||||
Torch::ValueTensorType resultType;
|
|
||||||
Value data;
|
|
||||||
int64_t keepDims, noop_with_empty_axes;
|
|
||||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
|
||||||
binder.tensorResultType(resultType) ||
|
|
||||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
|
||||||
binder.s64IntegerAttr(noop_with_empty_axes,
|
|
||||||
"noop_with_empty_axes", 0))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
return reducedSumImpl(binder, rewriter, data, resultType,
|
|
||||||
/*storeValue=*/data, keepDims,
|
|
||||||
noop_with_empty_axes, false);
|
|
||||||
});
|
|
||||||
patterns.onOp("ReduceLogSum", 1,
|
patterns.onOp("ReduceLogSum", 1,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
@ -987,6 +966,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, resultType, data);
|
binder.op, resultType, data);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp("ReduceSum", 1,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value data;
|
||||||
|
int64_t keepDims, noop_with_empty_axes;
|
||||||
|
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||||
|
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||||
|
"noop_with_empty_axes", 0))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return reducedSumImpl(binder, rewriter, data, resultType,
|
||||||
|
/*storeValue=*/data, keepDims,
|
||||||
|
noop_with_empty_axes, false);
|
||||||
|
});
|
||||||
|
patterns.onOp("ReduceSumSquare", 1,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value data;
|
||||||
|
int64_t keepDims, noop_with_empty_axes;
|
||||||
|
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||||
|
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||||
|
"noop_with_empty_axes", 0))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
|
||||||
|
binder.getLoc(), data.getType(), data, data);
|
||||||
|
|
||||||
|
return reducedSumImpl(binder, rewriter, dataSquare,
|
||||||
|
resultType,
|
||||||
|
/*storeValue=*/data, keepDims,
|
||||||
|
noop_with_empty_axes, false);
|
||||||
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"ReduceMean", 1,
|
"ReduceMean", 1,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
|
|
@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
||||||
if (!isUnsignedType)
|
if (!isUnsignedType)
|
||||||
return;
|
return;
|
||||||
int64_t minSI = -(1 << (numBits - 1));
|
int64_t minSI = -(1 << (numBits - 1));
|
||||||
Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
|
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
|
||||||
|
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
|
||||||
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
|
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
|
||||||
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
||||||
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
|
@ -797,6 +798,8 @@ public:
|
||||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||||
|
|
||||||
Value inputZp, weightZp;
|
Value inputZp, weightZp;
|
||||||
|
bool inputUnsigned = false;
|
||||||
|
bool weightUnsigned = false;
|
||||||
if (auto make = op.getInput()
|
if (auto make = op.getInput()
|
||||||
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
||||||
input = make.getSelf();
|
input = make.getSelf();
|
||||||
|
@ -806,6 +809,8 @@ public:
|
||||||
inputZp = typeConverter->materializeTargetConversion(
|
inputZp = typeConverter->materializeTargetConversion(
|
||||||
rewriter, loc, typeConverter->convertType(inputZp.getType()),
|
rewriter, loc, typeConverter->convertType(inputZp.getType()),
|
||||||
inputZp);
|
inputZp);
|
||||||
|
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
|
||||||
|
inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto make = op.getWeight()
|
if (auto make = op.getWeight()
|
||||||
|
@ -818,6 +823,8 @@ public:
|
||||||
weightZp = typeConverter->materializeTargetConversion(
|
weightZp = typeConverter->materializeTargetConversion(
|
||||||
rewriter, loc, typeConverter->convertType(weightZp.getType()),
|
rewriter, loc, typeConverter->convertType(weightZp.getType()),
|
||||||
weightZp);
|
weightZp);
|
||||||
|
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
|
||||||
|
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
|
if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
|
||||||
|
@ -916,15 +923,35 @@ public:
|
||||||
SmallVector<Value> strideIntValues =
|
SmallVector<Value> strideIntValues =
|
||||||
getAsConstantIntValues(rewriter, loc, strideInts);
|
getAsConstantIntValues(rewriter, loc, strideInts);
|
||||||
|
|
||||||
|
// convert any uint8 quantization to int8 quantization
|
||||||
|
if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
|
||||||
|
int64_t width = integerType.getWidth();
|
||||||
|
signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
|
||||||
|
}
|
||||||
|
if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
|
||||||
|
int64_t width = integerType.getWidth();
|
||||||
|
signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
|
||||||
|
}
|
||||||
// Pad the input tensor according to padding.
|
// Pad the input tensor according to padding.
|
||||||
SmallVector<Value> outDims{inBatch, weightBatch};
|
SmallVector<Value> outDims{inBatch, weightBatch};
|
||||||
Value paddedInput;
|
Value paddedInput;
|
||||||
if (transposed) {
|
Value pad = inputZp;
|
||||||
if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
|
if (!pad) {
|
||||||
!isa<mlir::FloatType>(resultDTy))
|
if (isa<mlir::FloatType>(inputDTy))
|
||||||
return rewriter.notifyMatchFailure(
|
pad = rewriter.create<arith::ConstantOp>(
|
||||||
op, "transpose does not support non-fp type yet");
|
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
|
||||||
|
if (isa<mlir::IntegerType>(inputDTy))
|
||||||
|
pad = rewriter.create<arith::ConstantOp>(
|
||||||
|
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
|
||||||
|
}
|
||||||
|
if (pad.getType() != inputDTy) {
|
||||||
|
if (isa<mlir::FloatType>(inputDTy))
|
||||||
|
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
|
||||||
|
|
||||||
|
if (isa<mlir::IntegerType>(inputDTy))
|
||||||
|
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
|
||||||
|
}
|
||||||
|
if (transposed) {
|
||||||
Value c0 =
|
Value c0 =
|
||||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||||
Value c1 =
|
Value c1 =
|
||||||
|
@ -994,7 +1021,7 @@ public:
|
||||||
|
|
||||||
// Allocate padded input tensor
|
// Allocate padded input tensor
|
||||||
Value initTensor =
|
Value initTensor =
|
||||||
createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
|
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
|
||||||
|
|
||||||
// Insert input into allocated tensor
|
// Insert input into allocated tensor
|
||||||
SmallVector<Value> strideIndexValues{c1, c1};
|
SmallVector<Value> strideIndexValues{c1, c1};
|
||||||
|
@ -1017,24 +1044,6 @@ public:
|
||||||
strideInts.clear();
|
strideInts.clear();
|
||||||
strideInts.append(numSpatialDims, 1);
|
strideInts.append(numSpatialDims, 1);
|
||||||
} else {
|
} else {
|
||||||
Value pad = inputZp;
|
|
||||||
if (!pad) {
|
|
||||||
if (isa<mlir::FloatType>(inputDTy))
|
|
||||||
pad = rewriter.create<arith::ConstantOp>(
|
|
||||||
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
|
|
||||||
if (isa<mlir::IntegerType>(inputDTy))
|
|
||||||
pad = rewriter.create<arith::ConstantOp>(
|
|
||||||
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pad.getType() != inputDTy) {
|
|
||||||
if (isa<mlir::FloatType>(inputDTy))
|
|
||||||
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
|
|
||||||
|
|
||||||
if (isa<mlir::IntegerType>(inputDTy))
|
|
||||||
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pad input
|
// Pad input
|
||||||
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
|
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
|
||||||
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
|
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
|
||||||
|
|
|
@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
auto constType = RankedTensorType::get({}, elementTy);
|
auto constType = RankedTensorType::get({}, elementTy);
|
||||||
// Avg pooling
|
// Avg pooling
|
||||||
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
||||||
AtenCumsumOp>(op)) {
|
AtenAvgPool3dOp, AtenCumsumOp>(op)) {
|
||||||
if (isa<mlir::FloatType>(elementTy)) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getZero(
|
constType, {APFloat::getZero(
|
||||||
|
@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max pooling
|
// Max pooling
|
||||||
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
|
||||||
|
AtenMaxPool2dWithIndicesOp>(op)) {
|
||||||
if (isa<mlir::FloatType>(elementTy)) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
|
@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// AtenMaxPool2dOp
|
|
||||||
template <>
|
|
||||||
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|
||||||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
Value input = adaptor.getSelf();
|
|
||||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
|
||||||
auto inputElemTy = inputTy.getElementType();
|
|
||||||
|
|
||||||
auto inputRank = inputTy.getRank();
|
|
||||||
auto outTy =
|
|
||||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
||||||
|
|
||||||
if (inputRank <= 2) {
|
|
||||||
return op.emitError(
|
|
||||||
"max_pooling2d only supports inputs with rank higher than 2");
|
|
||||||
}
|
|
||||||
SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
|
|
||||||
bool ceilMode = false;
|
|
||||||
|
|
||||||
if (!(matchPattern(op.getKernelSize(),
|
|
||||||
m_TorchListOfConstantInts(kernelSize)))) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "non-const int kernel size unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const int padding unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const int dilation unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const bool ceil_mode unsupported!");
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
|
||||||
// input
|
|
||||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
|
||||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
|
||||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
|
||||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
|
||||||
std::copy(dilation.begin(), dilation.end(),
|
|
||||||
stablehloDilation.begin() + inputRank - 2);
|
|
||||||
std::copy(stride.begin(), stride.end(),
|
|
||||||
stablehloStride.begin() + inputRank - 2);
|
|
||||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
|
||||||
stablehloKernelSize.begin() + inputRank - 2);
|
|
||||||
|
|
||||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
|
||||||
|
|
||||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
|
||||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
|
||||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
|
||||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
|
||||||
|
|
||||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
|
||||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
|
||||||
DenseI64ArrayAttr baseDilations;
|
|
||||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
|
||||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
|
||||||
RankedTensorType::get(
|
|
||||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
|
||||||
rewriter.getI64Type()),
|
|
||||||
stablehloPadding);
|
|
||||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
|
||||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
|
||||||
baseDilations, windowDilations, pad);
|
|
||||||
|
|
||||||
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
|
||||||
|
|
||||||
auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
|
|
||||||
block.addArgument(blockArgumentTy, op->getLoc());
|
|
||||||
block.addArgument(blockArgumentTy, op->getLoc());
|
|
||||||
|
|
||||||
auto *firstArg = block.args_begin();
|
|
||||||
auto secondArg = block.args_rbegin();
|
|
||||||
|
|
||||||
{
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(&block);
|
|
||||||
Value result =
|
|
||||||
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
|
||||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// AtenMaxPool2dWithIndicesOp
|
// AtenMaxPool2dWithIndicesOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
|
@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename AtenOpT, int Dim>
|
||||||
|
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
auto inputRank = inputTy.getRank();
|
||||||
|
auto outTy = cast<RankedTensorType>(
|
||||||
|
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
|
if (inputRank <= Dim) {
|
||||||
|
return op.emitError(
|
||||||
|
"max_pooling1d/2d only supports inputs with rank higher than 1/2");
|
||||||
|
}
|
||||||
|
SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
|
||||||
|
bool ceilMode = false;
|
||||||
|
|
||||||
|
if (!(matchPattern(op.getKernelSize(),
|
||||||
|
m_TorchListOfConstantInts(kernelSize)))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "non-const int kernel size unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int stride unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int padding unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getDilation(),
|
||||||
|
m_TorchListOfConstantInts(dilation)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int dilation unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "non-const bool ceil_mode unsupported!");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stride.empty()) {
|
||||||
|
stride = kernelSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepend 1 to kernelSize, stride, dilation until they are of same rank
|
||||||
|
// as input
|
||||||
|
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||||
|
std::copy(dilation.begin(), dilation.end(),
|
||||||
|
stablehloDilation.begin() + inputRank - Dim);
|
||||||
|
std::copy(stride.begin(), stride.end(),
|
||||||
|
stablehloStride.begin() + inputRank - Dim);
|
||||||
|
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||||
|
stablehloKernelSize.begin() + inputRank - Dim);
|
||||||
|
|
||||||
|
Value initVal =
|
||||||
|
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||||
|
|
||||||
|
if (Dim == 1) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||||
|
} else if (Dim == 2) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||||
|
} else if (Dim == 3) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
|
||||||
|
} else {
|
||||||
|
assert(false && "Unsupported pooling dimension");
|
||||||
|
}
|
||||||
|
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||||
|
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||||
|
DenseI64ArrayAttr baseDilations;
|
||||||
|
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||||
|
|
||||||
|
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||||
|
RankedTensorType::get(
|
||||||
|
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||||
|
rewriter.getI64Type()),
|
||||||
|
stablehloPadding);
|
||||||
|
|
||||||
|
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||||
|
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||||
|
baseDilations, windowDilations, pad);
|
||||||
|
|
||||||
|
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
||||||
|
|
||||||
|
// Add bb argument
|
||||||
|
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||||
|
block.addArgument(blockArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockArgumentType, op->getLoc());
|
||||||
|
auto *firstArg = block.args_begin();
|
||||||
|
auto secondArg = block.args_rbegin();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
|
||||||
|
Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
|
||||||
|
*secondArg);
|
||||||
|
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename AtenOpT, int Dim>
|
template <typename AtenOpT, int Dim>
|
||||||
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
|
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||||
|
@ -375,8 +404,8 @@ public:
|
||||||
auto outShape = outTy.getShape();
|
auto outShape = outTy.getShape();
|
||||||
|
|
||||||
if (inputRank <= Dim) {
|
if (inputRank <= Dim) {
|
||||||
return op.emitError(
|
return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
|
||||||
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
|
"higher than 1/2/3");
|
||||||
}
|
}
|
||||||
SmallVector<int64_t, Dim> padding, kernelSize, stride;
|
SmallVector<int64_t, Dim> padding, kernelSize, stride;
|
||||||
bool ceilMode = false;
|
bool ceilMode = false;
|
||||||
|
@ -405,6 +434,10 @@ public:
|
||||||
op, "non-const bool count_include_pad unsupported!");
|
op, "non-const bool count_include_pad unsupported!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (stride.empty()) {
|
||||||
|
stride = kernelSize;
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
||||||
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
|
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -425,11 +458,20 @@ public:
|
||||||
if (Dim == 1) {
|
if (Dim == 1) {
|
||||||
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||||
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||||
} else {
|
} else if (Dim == 2) {
|
||||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||||
|
} else if (Dim == 3) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
|
||||||
|
} else {
|
||||||
|
assert(false && "Unsupported pooling dimension");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value initVal =
|
Value initVal =
|
||||||
|
@ -474,10 +516,17 @@ public:
|
||||||
divisor =
|
divisor =
|
||||||
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
|
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
|
||||||
.value();
|
.value();
|
||||||
} else {
|
} else if (Dim == 2) {
|
||||||
divisor = hlo::getConstTensor<int64_t>(
|
divisor = hlo::getConstTensor<int64_t>(
|
||||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||||
.value();
|
.value();
|
||||||
|
} else if (Dim == 3) {
|
||||||
|
divisor = hlo::getConstTensor<int64_t>(
|
||||||
|
rewriter, op,
|
||||||
|
{kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
|
||||||
|
.value();
|
||||||
|
} else {
|
||||||
|
assert(false && "Unsupported pooling dimension");
|
||||||
}
|
}
|
||||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||||
DenseI64ArrayAttr bcastDimensions;
|
DenseI64ArrayAttr bcastDimensions;
|
||||||
|
@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
|
||||||
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
|
target.addIllegalOp<AtenOp>(); \
|
||||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
|
||||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
|
||||||
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
|
#undef INSERT_ATEN_POOLING_PATTERN
|
||||||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
|
||||||
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \
|
||||||
context, options);
|
target.addIllegalOp<AtenOp>(); \
|
||||||
target.addIllegalOp<AtenCumsumOp>();
|
patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
||||||
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
|
options)
|
||||||
|
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
|
||||||
|
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
|
||||||
|
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
|
||||||
|
#undef INSERT_ATEN_MAXPOOL_PATTERN
|
||||||
|
|
||||||
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
|
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
||||||
options)
|
options)
|
||||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
|
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
|
||||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
|
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
|
||||||
|
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
|
||||||
#undef INSERT_ATEN_AVGPOOL_PATTERN
|
#undef INSERT_ATEN_AVGPOOL_PATTERN
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/StablehloOps.h"
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
|
@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
|
Value tensor, int64_t collapseStartDim,
|
||||||
|
int64_t collapseEndDim,
|
||||||
|
size_t dimSizeIndexBits) {
|
||||||
|
|
||||||
|
auto dimSizesInfo =
|
||||||
|
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
|
||||||
|
|
||||||
|
if (failed(dimSizesInfo))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "failed to get dimension sizes of the input");
|
||||||
|
|
||||||
|
auto dimSizes = *dimSizesInfo;
|
||||||
|
int64_t rank = dimSizes.size();
|
||||||
|
|
||||||
|
collapseStartDim = toPositiveDim(collapseStartDim, rank);
|
||||||
|
collapseEndDim = toPositiveDim(collapseEndDim, rank);
|
||||||
|
|
||||||
|
int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
|
||||||
|
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||||
|
auto oldShape = rankTy.getShape();
|
||||||
|
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||||
|
|
||||||
|
std::vector<Value> newDimSizes;
|
||||||
|
std::vector<int64_t> newShape;
|
||||||
|
newDimSizes.reserve(newRank);
|
||||||
|
newShape.reserve(newRank);
|
||||||
|
|
||||||
|
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(intType, 1));
|
||||||
|
int64_t collapseShape = 1;
|
||||||
|
|
||||||
|
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
|
||||||
|
if (k < 0 || k >= rank) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "collapse dimensions must be within the rank of the tensor");
|
||||||
|
}
|
||||||
|
if (collapseShape == ShapedType::kDynamic ||
|
||||||
|
oldShape[k] == ShapedType::kDynamic) {
|
||||||
|
collapseShape = ShapedType::kDynamic;
|
||||||
|
} else {
|
||||||
|
collapseShape *= oldShape[k];
|
||||||
|
}
|
||||||
|
collapseDimSize =
|
||||||
|
rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t k = 0; k < collapseStartDim; ++k) {
|
||||||
|
newDimSizes.push_back(dimSizes[k]);
|
||||||
|
newShape.push_back(oldShape[k]);
|
||||||
|
}
|
||||||
|
newDimSizes.push_back(collapseDimSize);
|
||||||
|
newShape.push_back(collapseShape);
|
||||||
|
for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
|
||||||
|
newDimSizes.push_back(dimSizes[k]);
|
||||||
|
newShape.push_back(oldShape[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
||||||
|
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||||
|
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: support splitDim & outerLength to be Value
|
||||||
|
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
|
Value tensor, int64_t splitDim,
|
||||||
|
int64_t outerLength, size_t dimSizeIndexBits) {
|
||||||
|
auto dimSizesInfo =
|
||||||
|
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
|
||||||
|
|
||||||
|
if (failed(dimSizesInfo))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "failed to get dimension sizes of the input");
|
||||||
|
|
||||||
|
auto dimSizes = *dimSizesInfo;
|
||||||
|
int64_t rank = dimSizes.size();
|
||||||
|
splitDim = toPositiveDim(splitDim, rank);
|
||||||
|
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||||
|
auto oldShape = rankTy.getShape();
|
||||||
|
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||||
|
|
||||||
|
if (splitDim < 0 || splitDim >= rank) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "split dimensions must be within the rank of the tensor");
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t newRank = rank + 1;
|
||||||
|
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(intType, outerLength));
|
||||||
|
|
||||||
|
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
|
||||||
|
loc, dimSizes[splitDim], outerLengthValue);
|
||||||
|
|
||||||
|
int64_t originShape = oldShape[splitDim];
|
||||||
|
int64_t outerShape = outerLength;
|
||||||
|
int64_t innerShape = originShape == ShapedType::kDynamic
|
||||||
|
? ShapedType::kDynamic
|
||||||
|
: originShape / outerLength;
|
||||||
|
|
||||||
|
std::vector<Value> newDimSizes;
|
||||||
|
std::vector<int64_t> newShape;
|
||||||
|
|
||||||
|
newDimSizes.reserve(newRank);
|
||||||
|
newShape.reserve(newRank);
|
||||||
|
|
||||||
|
for (int64_t k = 0; k < splitDim; ++k) {
|
||||||
|
newDimSizes.push_back(dimSizes[k]);
|
||||||
|
newShape.push_back(oldShape[k]);
|
||||||
|
}
|
||||||
|
newDimSizes.push_back(outerLengthValue);
|
||||||
|
newShape.push_back(outerShape);
|
||||||
|
newDimSizes.push_back(innerLengthValue);
|
||||||
|
newShape.push_back(innerShape);
|
||||||
|
|
||||||
|
for (int64_t k = splitDim + 1; k < rank; ++k) {
|
||||||
|
newDimSizes.push_back(dimSizes[k]);
|
||||||
|
newShape.push_back(oldShape[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
||||||
|
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||||
|
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||||
const APFloat &constant, Value shape,
|
const APFloat &constant, Value shape,
|
||||||
TensorType outType) {
|
TensorType outType) {
|
||||||
|
|
|
@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only constant end is currently supported");
|
op, "only constant end is currently supported");
|
||||||
|
|
||||||
start = toPositiveDim(start, rank);
|
auto collapseTensorInfo = hlo::collapseTensor(
|
||||||
end = toPositiveDim(end, rank);
|
rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
|
||||||
SmallVector<int64_t, 4> dims;
|
if (failed(collapseTensorInfo))
|
||||||
dims.reserve(rank);
|
return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
|
||||||
for (int r = 0; r < start; ++r)
|
|
||||||
dims.push_back(r);
|
|
||||||
int64_t collapsedDimSize = 1;
|
|
||||||
for (int r = start; r <= end; ++r) {
|
|
||||||
if (selfType.getShape()[r] == ShapedType::kDynamic)
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "the size of the dimension being collapsed is can't be unknown");
|
|
||||||
collapsedDimSize *= selfType.getShape()[r];
|
|
||||||
}
|
|
||||||
dims.push_back(collapsedDimSize);
|
|
||||||
for (int r = end + 1; r < rank; ++r)
|
|
||||||
dims.push_back(r);
|
|
||||||
|
|
||||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(
|
rewriter.replaceOp(op, *collapseTensorInfo);
|
||||||
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
|
return success();
|
||||||
if (failed(newDimSizesInfo))
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
|
||||||
|
PrimsSplitDimOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
||||||
|
if (!selfType) {
|
||||||
|
return op.emitError("only tensor types are currently supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rank = selfType.getRank();
|
||||||
|
if (rank == 0)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "failed to get dimension sizes of the input");
|
op, "the rank of tensor must be greater than 0");
|
||||||
auto newDimSizes = *newDimSizesInfo;
|
|
||||||
auto stablehloShape =
|
int64_t dim, outerLength;
|
||||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
return rewriter.notifyMatchFailure(
|
||||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
|
op, "only constant dim is currently supported");
|
||||||
stablehloShape);
|
if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "only constant outerLength is currently supported");
|
||||||
|
|
||||||
|
auto splitTensorInfo = hlo::splitTensor(
|
||||||
|
rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
|
||||||
|
|
||||||
|
if (failed(splitTensorInfo))
|
||||||
|
return rewriter.notifyMatchFailure(op, "failed to create split tensor");
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, *splitTensorInfo);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
|
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||||
|
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
|
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
|
||||||
|
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
|
||||||
|
auto dtypeElemType = dtypeComplex.getElementType();
|
||||||
|
|
||||||
|
// Extract the real and imaginary parts of the scalar.
|
||||||
|
// Cast them to the target element type, and create a new complex
|
||||||
|
// value with the target complex type.
|
||||||
|
Value realVal = b.create<complex::ReOp>(loc, scalar);
|
||||||
|
Value imgVal = b.create<complex::ImOp>(loc, scalar);
|
||||||
|
|
||||||
|
realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
|
||||||
|
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
|
||||||
|
|
||||||
|
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||||
|
}
|
||||||
|
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
||||||
|
<< scalarType << "(scalar type) -> " << dtype
|
||||||
|
<< "(dtype)";
|
||||||
|
}
|
||||||
|
|
||||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
|
||||||
// If no operand comes from a tensor::CastOp and can be folded then fail.
|
// If no operand comes from a tensor::CastOp and can be folded then fail.
|
||||||
bool hasTensorCastOperand =
|
bool hasTensorCastOperand =
|
||||||
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
|
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
|
||||||
if (opOperand->get().isa<BlockArgument>())
|
if (isa<BlockArgument>(opOperand->get()))
|
||||||
return false;
|
return false;
|
||||||
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
|
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
|
||||||
return castOp && canFoldIntoConsumerOp(castOp);
|
return castOp && canFoldIntoConsumerOp(castOp);
|
||||||
|
|
|
@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
Type inputDtype = inputTensorType.getOptionalDtype();
|
Type inputDtype = inputTensorType.getOptionalDtype();
|
||||||
if (!inputDtype || !inputDtype.isInteger(64))
|
if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
std::optional<unsigned> inputRank = getTensorRank(input);
|
std::optional<unsigned> inputRank = getTensorRank(input);
|
||||||
|
@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||||
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
|
if (inputDtype.isInteger(64)) {
|
||||||
|
auto val = valueTensorLiteralOp.getValue()
|
||||||
|
.cast<DenseIntElementsAttr>()
|
||||||
.getSplatValue<int64_t>();
|
.getSplatValue<int64_t>();
|
||||||
return rewriter.create<Torch::ConstantIntOp>(
|
return rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(val));
|
loc, rewriter.getI64IntegerAttr(val));
|
||||||
|
} else {
|
||||||
|
auto val = valueTensorLiteralOp.getValue()
|
||||||
|
.cast<DenseIntElementsAttr>()
|
||||||
|
.getSplatValue<bool>();
|
||||||
|
return rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(val));
|
||||||
|
}
|
||||||
} else if (auto primNumToTensorScalarOp =
|
} else if (auto primNumToTensorScalarOp =
|
||||||
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
||||||
return primNumToTensorScalarOp.getA();
|
return primNumToTensorScalarOp.getA();
|
||||||
|
@ -2385,6 +2394,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Aten__Contains__StrListOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
|
||||||
|
StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
|
||||||
|
if (!item)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
|
||||||
|
if (isListPotentiallyMutated(listConstruct))
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
llvm::SmallVector<std::string> strs;
|
||||||
|
if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
|
||||||
|
for (const auto &str : strs) {
|
||||||
|
if (item.getValue().str() == str)
|
||||||
|
return getI1IntegerAttr(getContext(), true);
|
||||||
|
}
|
||||||
|
return getI1IntegerAttr(getContext(), false);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenLtIntOp
|
// AtenLtIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -4682,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PrimsConvertElementTypeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto inputType = cast<BaseTensorType>(getA().getType());
|
||||||
|
auto outputType = cast<BaseTensorType>(getResult().getType());
|
||||||
|
if (inputType != outputType)
|
||||||
|
return nullptr;
|
||||||
|
if (!inputType.hasDtype() || !outputType.hasDtype())
|
||||||
|
return nullptr;
|
||||||
|
if (inputType.getDtype() != outputType.getDtype())
|
||||||
|
return nullptr;
|
||||||
|
return getA();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenMaxPool2dWithIndicesOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &patterns, MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
|
||||||
|
if (!op.getResult1().use_empty()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "result1 of MaxPool2dWithIndices should be unused");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
|
||||||
|
op->getLoc(), op.getResult0().getType(), op.getSelf(),
|
||||||
|
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
|
||||||
|
op.getCeilMode());
|
||||||
|
|
||||||
|
op.getResult0().replaceAllUsesWith(result);
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenLinalgCrossOp
|
// AtenLinalgCrossOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -7841,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" return %arg2 : !torch.list<int>\n"
|
" return %arg2 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
|
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" %int-1 = torch.constant.int -1\n"
|
" %int-1 = torch.constant.int -1\n"
|
||||||
" %int-2 = torch.constant.int -2\n"
|
" %int-2 = torch.constant.int -2\n"
|
||||||
" %int-3 = torch.constant.int -3\n"
|
" %int-3 = torch.constant.int -3\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
|
" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
|
||||||
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
|
" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
|
||||||
" %true = torch.constant.bool true\n"
|
" %true = torch.constant.bool true\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
|
" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
|
||||||
" %int1 = torch.constant.int 1\n"
|
" %int1 = torch.constant.int 1\n"
|
||||||
" %int0 = torch.constant.int 0\n"
|
" %int0 = torch.constant.int 0\n"
|
||||||
" %int2 = torch.constant.int 2\n"
|
" %int2 = torch.constant.int 2\n"
|
||||||
|
@ -7936,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %23 : !torch.list<int>\n"
|
" return %23 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -10480,6 +10488,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
|
|
@ -1059,44 +1059,44 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
int64_t n;
|
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
||||||
|
|
||||||
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"unimplemented: n must be constant");
|
|
||||||
int64_t m;
|
|
||||||
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"unimplemented: m must be constant");
|
|
||||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
||||||
auto outType = dyn_cast<BaseTensorType>(op.getType());
|
|
||||||
if (!outType)
|
if (!outType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
if (!outType.hasDtype()) {
|
if (!outType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
if (n < 0) {
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
|
|
||||||
}
|
|
||||||
if (m < 0) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto context = op.getContext();
|
auto context = op.getContext();
|
||||||
auto int64Dtype = getDtypeIntValueForType(
|
auto int64Dtype = getDtypeIntValueForType(
|
||||||
rewriter, loc,
|
rewriter, loc,
|
||||||
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||||
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
|
|
||||||
|
int64_t n = kUnknownSize;
|
||||||
|
int64_t m = kUnknownSize;
|
||||||
|
// prioritize getting shape from output shape
|
||||||
|
if (outType.hasSizes() && outType.getSizes().size() == 2) {
|
||||||
|
n = outType.getSizes().front();
|
||||||
|
m = outType.getSizes().back();
|
||||||
|
}
|
||||||
|
// if output shape is not available, try to get shape from input
|
||||||
|
if (n == kUnknownSize)
|
||||||
|
matchPattern(op.getN(), m_TorchConstantInt(&n));
|
||||||
|
if (m == kUnknownSize)
|
||||||
|
matchPattern(op.getM(), m_TorchConstantInt(&m));
|
||||||
|
|
||||||
|
// prepare two unsqueezed ranges that are equal on and only on the diagonal
|
||||||
|
auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
|
||||||
|
Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
|
||||||
Value rangeN = rewriter.create<AtenArangeOp>(
|
Value rangeN = rewriter.create<AtenArangeOp>(
|
||||||
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
|
loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
|
||||||
/*device=*/op.getDevice(), /*pin_memory=*/none);
|
/*device=*/op.getDevice(), /*pin_memory=*/none);
|
||||||
|
|
||||||
auto arangeType1 =
|
auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
|
||||||
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
|
Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
|
||||||
Value rangeM = rewriter.create<AtenArangeOp>(
|
Value rangeM = rewriter.create<AtenArangeOp>(
|
||||||
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
|
loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
|
||||||
/*device=*/none, /*pin_memory=*/none);
|
/*device=*/none, /*pin_memory=*/none);
|
||||||
|
|
||||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -1109,7 +1109,6 @@ public:
|
||||||
}
|
}
|
||||||
Value unsqzRangeN = *unsqzTensorInfo;
|
Value unsqzRangeN = *unsqzTensorInfo;
|
||||||
|
|
||||||
// compare unsqueezed input with boundaries
|
|
||||||
auto eqType = ValueTensorType::get(
|
auto eqType = ValueTensorType::get(
|
||||||
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
||||||
IntegerType::get(context, 1));
|
IntegerType::get(context, 1));
|
||||||
|
@ -2415,6 +2414,50 @@ public:
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenCeluOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value input = op.getSelf();
|
||||||
|
Value alpha = op.getAlpha();
|
||||||
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value constantZero =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value constantOne =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||||
|
|
||||||
|
// positiveOutput = max(0,x)
|
||||||
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
||||||
|
Value positiveOutput =
|
||||||
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
|
||||||
|
|
||||||
|
// negativeOutput = min(0,alpha∗(exp(x/alpha)−1))
|
||||||
|
Value scaledInput =
|
||||||
|
rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
|
||||||
|
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
|
||||||
|
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
|
||||||
|
constantOne, constantOne);
|
||||||
|
Value scaledExpXM1 =
|
||||||
|
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
|
||||||
|
Value negativeOutput =
|
||||||
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
|
||||||
|
Value celuOutput = rewriter.create<AtenAddTensorOp>(
|
||||||
|
loc, resType, positiveOutput, negativeOutput, constantOne);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, celuOutput);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
|
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -7705,6 +7748,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
||||||
|
|
|
@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
|
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
|
||||||
target.addIllegalOp<AtenPadOp>();
|
target.addIllegalOp<AtenPadOp>();
|
||||||
target.addIllegalOp<AtenPreluOp>();
|
target.addIllegalOp<AtenPreluOp>();
|
||||||
|
target.addIllegalOp<AtenCeluOp>();
|
||||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||||
target.addIllegalOp<AtenToDeviceOp>();
|
target.addIllegalOp<AtenToDeviceOp>();
|
||||||
target.addIllegalOp<AtenToPrimDeviceOp>();
|
target.addIllegalOp<AtenToPrimDeviceOp>();
|
||||||
|
|
|
@ -272,6 +272,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"QuantizedReluInt8_basic",
|
"QuantizedReluInt8_basic",
|
||||||
"QuantizedReluUint8_basic",
|
"QuantizedReluUint8_basic",
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
|
"ConvTranspose2DQInt8_basic",
|
||||||
# Dynamo not supporting conv_tbc
|
# Dynamo not supporting conv_tbc
|
||||||
"ConvTbcModule_basic",
|
"ConvTbcModule_basic",
|
||||||
"FloatImplicitModule_basic",
|
"FloatImplicitModule_basic",
|
||||||
|
@ -372,6 +373,7 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
"ConvTbcModule_basic",
|
"ConvTbcModule_basic",
|
||||||
|
"ConvTranspose2DQInt8_basic",
|
||||||
"ConvolutionBackwardModule2DPadded_basic",
|
"ConvolutionBackwardModule2DPadded_basic",
|
||||||
"ConvolutionBackwardModule2DStrided_basic",
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
"ConvolutionBackwardModule2D_basic",
|
"ConvolutionBackwardModule2D_basic",
|
||||||
|
@ -544,6 +546,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ContainsIntList_True",
|
"ContainsIntList_True",
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
"ConvTbcModule_basic",
|
"ConvTbcModule_basic",
|
||||||
|
"ConvTranspose2DQInt8_basic",
|
||||||
"ConvolutionBackwardModule2DPadded_basic",
|
"ConvolutionBackwardModule2DPadded_basic",
|
||||||
"ConvolutionBackwardModule2DStrided_basic",
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
"ConvolutionBackwardModule2D_basic",
|
"ConvolutionBackwardModule2D_basic",
|
||||||
|
@ -572,6 +575,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ElementwiseErfIntModule_basic",
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseLogitModule_basic",
|
"ElementwiseLogitModule_basic",
|
||||||
"ElementwiseMulTensorComplexModule_basic",
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
|
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
"ElementwiseReciprocalIntModule_basic",
|
"ElementwiseReciprocalIntModule_basic",
|
||||||
|
@ -678,11 +682,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"NumToTensorIntModule_basic",
|
"NumToTensorIntModule_basic",
|
||||||
"NumelModule_basic",
|
"NumelModule_basic",
|
||||||
"NumelZeroRankModule_basic",
|
"NumelZeroRankModule_basic",
|
||||||
"PixelShuffleModuleFullDynamic_basic",
|
|
||||||
"PixelShuffleModuleSpatiallyDynamic_basic",
|
|
||||||
"PixelShuffleModuleSpatiallyStatic_basic",
|
|
||||||
"PixelShuffleModuleStaticRank3Int64_basic",
|
|
||||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
|
||||||
"PowIntFloatModule_basic",
|
"PowIntFloatModule_basic",
|
||||||
"PrimMaxIntModule_basic",
|
"PrimMaxIntModule_basic",
|
||||||
"PrimMinIntDynamicModule_basic",
|
"PrimMinIntDynamicModule_basic",
|
||||||
|
@ -951,6 +950,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseBitwiseRightShiftInt64Module_basic",
|
"ElementwiseBitwiseRightShiftInt64Module_basic",
|
||||||
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||||
"ElementwiseCeilModule_basic",
|
"ElementwiseCeilModule_basic",
|
||||||
|
"ElementwiseCeluStaticModule_basic",
|
||||||
"ElementwiseClampMaxModule_basic",
|
"ElementwiseClampMaxModule_basic",
|
||||||
"ElementwiseClampMinModule_basic",
|
"ElementwiseClampMinModule_basic",
|
||||||
"ElementwiseClampMinTensorFloatModule_basic",
|
"ElementwiseClampMinTensorFloatModule_basic",
|
||||||
|
@ -1079,6 +1079,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"Matmul_vecmat",
|
"Matmul_vecmat",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
"MaxPool2dStaticModule_basic",
|
"MaxPool2dStaticModule_basic",
|
||||||
|
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||||
|
"MaxPool3dStaticModule_basic",
|
||||||
|
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||||
"MeanDimAllReduceModule_basic",
|
"MeanDimAllReduceModule_basic",
|
||||||
"MeanDimEmptyDimModule_basic",
|
"MeanDimEmptyDimModule_basic",
|
||||||
"MeanDimNoneDimModule_basic",
|
"MeanDimNoneDimModule_basic",
|
||||||
|
@ -1156,6 +1159,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"Permute0RankModule_basic",
|
"Permute0RankModule_basic",
|
||||||
"PermuteModule_basic",
|
"PermuteModule_basic",
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
|
"PixelShuffleModuleStaticRank3Int64_basic",
|
||||||
|
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||||
"PowIntFloatModule_basic",
|
"PowIntFloatModule_basic",
|
||||||
"PrimListUnpackNumMismatchModule_basic",
|
"PrimListUnpackNumMismatchModule_basic",
|
||||||
"PrimMaxIntModule_basic",
|
"PrimMaxIntModule_basic",
|
||||||
|
@ -1239,6 +1244,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"SliceWholeTensorModule_basic",
|
"SliceWholeTensorModule_basic",
|
||||||
"SortIntListReverse_basic",
|
"SortIntListReverse_basic",
|
||||||
"SortIntList_basic",
|
"SortIntList_basic",
|
||||||
|
"SplitDimStaticModule_basic",
|
||||||
"SplitTensorGetItem_Module_basic",
|
"SplitTensorGetItem_Module_basic",
|
||||||
"SplitTensorLastSmallerModule_basic",
|
"SplitTensorLastSmallerModule_basic",
|
||||||
"SplitTensorListUnpackModule_basic",
|
"SplitTensorListUnpackModule_basic",
|
||||||
|
@ -1571,6 +1577,8 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseBitwiseXorModule_basic",
|
"ElementwiseBitwiseXorModule_basic",
|
||||||
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
||||||
"ElementwiseCeilModule_basic",
|
"ElementwiseCeilModule_basic",
|
||||||
|
"ElementwiseCeluModule_basic",
|
||||||
|
"ElementwiseCeluStaticModule_basic",
|
||||||
"ElementwiseClampMaxModule_basic",
|
"ElementwiseClampMaxModule_basic",
|
||||||
"ElementwiseClampMinModule_basic",
|
"ElementwiseClampMinModule_basic",
|
||||||
"ElementwiseClampModule_basic",
|
"ElementwiseClampModule_basic",
|
||||||
|
@ -1916,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
# Dynamic shape, has extra unsupported broadcast ops
|
# Dynamic shape, has extra unsupported broadcast ops
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
|
||||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
|
||||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
|
||||||
"MaxPool2dStaticModule_basic",
|
|
||||||
"ResNet18StaticModule_basic",
|
|
||||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||||
|
@ -2096,6 +2099,7 @@ LTC_XFAIL_SET = {
|
||||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
|
"ConvTranspose2DQInt8_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_XFAIL_SET = {
|
ONNX_XFAIL_SET = {
|
||||||
|
@ -2121,7 +2125,6 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||||
"ElementwiseLog10IntModule_basic",
|
"ElementwiseLog10IntModule_basic",
|
||||||
"ElementwiseLog2IntModule_basic",
|
"ElementwiseLog2IntModule_basic",
|
||||||
"ElementwiseSeluModule_basic",
|
|
||||||
"FlipModuleStaticShape_basic",
|
"FlipModuleStaticShape_basic",
|
||||||
"FlipNegativeIndexModule_basic",
|
"FlipNegativeIndexModule_basic",
|
||||||
"HardsigmoidModule_basic",
|
"HardsigmoidModule_basic",
|
||||||
|
@ -2251,6 +2254,7 @@ ONNX_XFAIL_SET = {
|
||||||
"Conv2dWithPaddingModule_basic",
|
"Conv2dWithPaddingModule_basic",
|
||||||
"Conv3dModule_basic",
|
"Conv3dModule_basic",
|
||||||
"ConvTbcModule_basic",
|
"ConvTbcModule_basic",
|
||||||
|
"ConvTranspose2DQInt8_basic",
|
||||||
"Conv_Transpose2dModule_basic",
|
"Conv_Transpose2dModule_basic",
|
||||||
"Convolution2DModule_basic",
|
"Convolution2DModule_basic",
|
||||||
"Convolution2DStridedModule_basic",
|
"Convolution2DStridedModule_basic",
|
||||||
|
@ -2306,6 +2310,7 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseExpm1Module_basic",
|
"ElementwiseExpm1Module_basic",
|
||||||
"ElementwiseFmodTensor_Int_basic",
|
"ElementwiseFmodTensor_Int_basic",
|
||||||
"ElementwiseMulTensorComplexModule_basic",
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
|
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||||
"ElementwiseOrTensorModule_basic",
|
"ElementwiseOrTensorModule_basic",
|
||||||
"ElementwiseOrTensorStaticShapeModule_basic",
|
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
|
@ -2554,16 +2559,12 @@ ONNX_XFAIL_SET = {
|
||||||
"_ConvolutionDeprecated2DCudnnModule_basic",
|
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||||
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||||
"_SoftmaxModule_basic",
|
"_SoftmaxModule_basic",
|
||||||
|
# Failure - onnx_import
|
||||||
# Failure - onnx_lowering: onnx.AveragePool
|
# Failure - onnx_lowering: onnx.AveragePool
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
# Failure - onnx_lowering: onnx.If
|
# these diagonal modules are currently failing due to dynamic shape.
|
||||||
"DiagonalModule_basic",
|
# We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
|
||||||
"DiagonalModule_nonsquare",
|
# when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
|
||||||
"DiagonalModule_transposed",
|
|
||||||
"DiagonalModule_with_dims",
|
|
||||||
"DiagonalModule_with_dims_and_offset",
|
|
||||||
"DiagonalModule_with_negative_dims",
|
|
||||||
"DiagonalModule_with_offset",
|
|
||||||
"TileBigDimsSizeModule_basic",
|
"TileBigDimsSizeModule_basic",
|
||||||
"TileSmallDimsSizeModule_basic",
|
"TileSmallDimsSizeModule_basic",
|
||||||
# Failure - onnx_lowering: onnx.MaxPool
|
# Failure - onnx_lowering: onnx.MaxPool
|
||||||
|
@ -2634,8 +2635,6 @@ ONNX_XFAIL_SET = {
|
||||||
"CopyWithDifferentDTypesModule_basic",
|
"CopyWithDifferentDTypesModule_basic",
|
||||||
"CosineSimilarityStaticBroadcastModule_basic",
|
"CosineSimilarityStaticBroadcastModule_basic",
|
||||||
"CumsumInputDtypeInt32Module_basic",
|
"CumsumInputDtypeInt32Module_basic",
|
||||||
"DropoutTrainModule_basic",
|
|
||||||
"DropoutTrainStaticShapeModule_basic",
|
|
||||||
"ElementwiseAcosIntModule_basic",
|
"ElementwiseAcosIntModule_basic",
|
||||||
"ElementwiseAsinIntModule_basic",
|
"ElementwiseAsinIntModule_basic",
|
||||||
"ElementwiseAtanTensorIntModule_basic",
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
|
|
|
@ -3,6 +3,7 @@ from torch_mlir import torchscript
|
||||||
|
|
||||||
from transformers import BertForMaskedLM
|
from transformers import BertForMaskedLM
|
||||||
|
|
||||||
|
|
||||||
# Wrap the bert model to avoid multiple returns problem
|
# Wrap the bert model to avoid multiple returns problem
|
||||||
class BertTinyWrapper(torch.nn.Module):
|
class BertTinyWrapper(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
|
@ -257,9 +257,9 @@ class _FXGraphImporter:
|
||||||
# FakeTensor's in case of a tuple return with multiple elements.
|
# FakeTensor's in case of a tuple return with multiple elements.
|
||||||
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
|
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
|
||||||
self._module = ir.Module.create(ir.Location.unknown())
|
self._module = ir.Module.create(ir.Location.unknown())
|
||||||
self._module.operation.attributes[
|
self._module.operation.attributes["torch.debug_module_name"] = (
|
||||||
"torch.debug_module_name"
|
ir.StringAttr.get(func_name)
|
||||||
] = ir.StringAttr.get(func_name)
|
)
|
||||||
function_type = _extract_function_type_from_graph(g)
|
function_type = _extract_function_type_from_graph(g)
|
||||||
func = func_dialect.FuncOp(
|
func = func_dialect.FuncOp(
|
||||||
func_name,
|
func_name,
|
||||||
|
|
|
@ -526,6 +526,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
|
||||||
def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
|
def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -958,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
|
||||||
|
|
||||||
# TODO: This should be upstreamed.
|
# TODO: This should be upstreamed.
|
||||||
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
||||||
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
|
def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
|
||||||
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
|
assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
|
||||||
kL = kernel_size[0]
|
kL = kernel_size[0]
|
||||||
|
|
||||||
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
|
assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
|
||||||
dL = kL if len(stride) == 0 else stride[0]
|
dL = kL if len(stride) == 0 else stride[0]
|
||||||
|
|
||||||
assert len(padding) == 1, "avg_pool1d: padding must be a single int"
|
assert len(padding) == 1, "pool1d: padding must be a single int"
|
||||||
padL = padding[0]
|
padL = padding[0]
|
||||||
|
|
||||||
dilationL = 1
|
dilationL = 1
|
||||||
|
@ -1001,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
|
||||||
return shape
|
return shape
|
||||||
|
|
||||||
def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
|
def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
|
||||||
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
|
return pool1d(self, kernel_size, stride, padding, ceil_mode)
|
||||||
|
|
||||||
|
def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
|
||||||
|
return pool1d(self, kernel_size, stride, padding, ceil_mode)
|
||||||
|
|
||||||
def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
|
def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
|
||||||
return adaptive_avg_pool1d(self, output_size)
|
return adaptive_avg_pool1d(self, output_size)
|
||||||
|
@ -2652,6 +2658,11 @@ def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu
|
||||||
assert self_dtype == weight_dtype
|
assert self_dtype == weight_dtype
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.))
|
||||||
|
def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
||||||
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
|
|
@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
(ns, unqual + "_", overload if not is_functional_op else "")
|
(ns, unqual + "_", overload if not is_functional_op else "")
|
||||||
),
|
),
|
||||||
emitter_td,
|
emitter_td,
|
||||||
traits=["IsTrailingUnderscoreInplaceVariant"]
|
traits=(
|
||||||
if not is_functional_op
|
["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
|
||||||
else [],
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
|
@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::real : (Tensor) -> (Tensor)")
|
emit("aten::real : (Tensor) -> (Tensor)")
|
||||||
emit("aten::imag : (Tensor) -> (Tensor)")
|
emit("aten::imag : (Tensor) -> (Tensor)")
|
||||||
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
||||||
|
@ -590,9 +591,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
|
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
)
|
)
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||||
|
@ -973,6 +976,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::format : (...) -> (str)")
|
emit("aten::format : (...) -> (str)")
|
||||||
emit("aten::join : (str, str[]) -> (str)")
|
emit("aten::join : (str, str[]) -> (str)")
|
||||||
emit("aten::warn : (str, int) -> ()")
|
emit("aten::warn : (str, int) -> ()")
|
||||||
|
emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True)
|
||||||
|
|
||||||
# Type conversion ops.
|
# Type conversion ops.
|
||||||
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
|
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
|
||||||
|
@ -1101,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
# `prims::` namespace.
|
# `prims::` namespace.
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
|
|
||||||
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
|
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||||
emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
|
emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
|
||||||
emit("prims::sqrt : (Tensor) -> (Tensor)")
|
emit("prims::sqrt : (Tensor) -> (Tensor)")
|
||||||
emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
|
emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
|
||||||
|
|
|
@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
|
||||||
examples = []
|
examples = []
|
||||||
input_names = []
|
input_names = []
|
||||||
dynamic_tensors = {}
|
dynamic_tensors = {}
|
||||||
for (index, arg) in enumerate(inputs):
|
for index, arg in enumerate(inputs):
|
||||||
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
|
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
|
||||||
shape = tuple(shape)
|
shape = tuple(shape)
|
||||||
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
||||||
|
@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
|
||||||
input_names.append(input_name)
|
input_names.append(input_name)
|
||||||
|
|
||||||
dynamic_dims = {}
|
dynamic_dims = {}
|
||||||
for (dimindex, dim) in enumerate(arg.shape):
|
for dimindex, dim in enumerate(arg.shape):
|
||||||
if dim < 0:
|
if dim < 0:
|
||||||
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
||||||
|
|
||||||
|
|
|
@ -101,11 +101,13 @@ class RefBackendInvoker:
|
||||||
def consume_return_funcs(*args):
|
def consume_return_funcs(*args):
|
||||||
self.result = tuple(
|
self.result = tuple(
|
||||||
[
|
[
|
||||||
|
(
|
||||||
arg
|
arg
|
||||||
if type in elemental_type_to_ctype
|
if type in elemental_type_to_ctype
|
||||||
else unranked_memref_to_numpy(
|
else unranked_memref_to_numpy(
|
||||||
arg, memref_type_to_np_dtype[type]
|
arg, memref_type_to_np_dtype[type]
|
||||||
)
|
)
|
||||||
|
)
|
||||||
for arg, type in zip(args, ret_types)
|
for arg, type in zip(args, ret_types)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -178,6 +180,7 @@ LOWERING_PIPELINE = (
|
||||||
"func.func(tm-tensor-to-loops)",
|
"func.func(tm-tensor-to-loops)",
|
||||||
"func.func(refback-munge-memref-copy)",
|
"func.func(refback-munge-memref-copy)",
|
||||||
"func.func(convert-linalg-to-loops)",
|
"func.func(convert-linalg-to-loops)",
|
||||||
|
"func.func(expand-realloc)",
|
||||||
"func.func(lower-affine)",
|
"func.func(lower-affine)",
|
||||||
"convert-scf-to-cf",
|
"convert-scf-to-cf",
|
||||||
"func.func(refback-expand-ops-for-llvm)",
|
"func.func(refback-expand-ops-for-llvm)",
|
||||||
|
@ -191,6 +194,7 @@ LOWERING_PIPELINE = (
|
||||||
"convert-bufferization-to-memref",
|
"convert-bufferization-to-memref",
|
||||||
"finalize-memref-to-llvm",
|
"finalize-memref-to-llvm",
|
||||||
"func.func(convert-arith-to-llvm)",
|
"func.func(convert-arith-to-llvm)",
|
||||||
|
"convert-vector-to-llvm",
|
||||||
"convert-func-to-llvm",
|
"convert-func-to-llvm",
|
||||||
"convert-cf-to-llvm",
|
"convert-cf-to-llvm",
|
||||||
"convert-complex-to-llvm",
|
"convert-complex-to-llvm",
|
||||||
|
|
|
@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
|
||||||
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
|
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
|
||||||
bias = torch.rand(3)
|
bias = torch.rand(3)
|
||||||
module.forward(inputVec, weight, bias)
|
module.forward(inputVec, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
N = 10
|
||||||
|
Cin = 5
|
||||||
|
Cout = 7
|
||||||
|
Hin = 10
|
||||||
|
Win = 8
|
||||||
|
Hker = 3
|
||||||
|
Wker = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ConvTranspose2DQInt8Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.int8, True),
|
||||||
|
([-1, -1, -1, -1], torch.int8, True),
|
||||||
|
([-1], torch.float, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, input, weight, bias):
|
||||||
|
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
|
||||||
|
qinput = torch.dequantize(qinput)
|
||||||
|
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
|
||||||
|
qweight = torch.dequantize(qweight)
|
||||||
|
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
|
||||||
|
qbias = torch.dequantize(qbias)
|
||||||
|
qz = torch.ops.aten.convolution(
|
||||||
|
qinput,
|
||||||
|
qweight,
|
||||||
|
bias=qbias,
|
||||||
|
stride=[2, 1],
|
||||||
|
padding=[1, 1],
|
||||||
|
dilation=[1, 1],
|
||||||
|
transposed=True,
|
||||||
|
output_padding=[0, 0],
|
||||||
|
groups=1,
|
||||||
|
)
|
||||||
|
return qz
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
|
||||||
|
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
|
||||||
|
torch.rand(Cout),
|
||||||
|
)
|
||||||
|
|
|
@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalWithStaticShapeModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Diagonal with static shape. The other diagonal modules are failing in onnx
|
||||||
|
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
|
||||||
|
when the shape is static.
|
||||||
|
|
||||||
|
Please remove this module and associated test once the issue is fixed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([5, 9], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.diagonal(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
|
||||||
|
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 9))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DiagonalTransposedModule(torch.nn.Module):
|
class DiagonalTransposedModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: QuantizedReluInt32())
|
@register_test_case(module_factory=lambda: QuantizedReluInt32())
|
||||||
def QuantizedReluInt32_basic(module, tu: TestUtils):
|
def QuantizedReluInt32_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))
|
||||||
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
@ -1016,6 +1014,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseCeluModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.celu(x, 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseCeluModule())
|
||||||
|
def ElementwiseCeluModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 3, low=-1, high=1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseCeluStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([5, 3], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.celu(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule())
|
||||||
|
def ElementwiseCeluStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 3, low=-1, high=1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseGeluModule(torch.nn.Module):
|
class ElementwiseGeluModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1795,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# torch.complex32 is not supported by the refbackend.
|
||||||
|
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1], torch.complex64, True),
|
||||||
|
([-1], torch.complex128, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.mul(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
|
||||||
|
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(4, high=10).type(torch.complex64),
|
||||||
|
tu.randint(4, high=10).type(torch.complex128),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseMishModule(torch.nn.Module):
|
class ElementwiseMishModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
|
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
|
||||||
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
|
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
|
||||||
class SliceScatterModule(torch.nn.Module):
|
class SliceScatterModule(torch.nn.Module):
|
||||||
|
|
|
@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK: module attributes {torch.debug_module_name = "TestModule"}
|
# CHECK: module attributes {torch.debug_module_name = "TestModule"}
|
||||||
class TestModule(torch.nn.Module):
|
class TestModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -18,6 +18,7 @@ mb = ModuleBuilder()
|
||||||
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
|
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
|
||||||
# naively duplicating a Tensor retains the identity of the TensorImpl.
|
# naively duplicating a Tensor retains the identity of the TensorImpl.
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
||||||
class TestModule(torch.nn.Module):
|
class TestModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
|
||||||
class TestModule(torch.nn.Module):
|
class TestModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func.func @__torch__.add3
|
# CHECK-LABEL: func.func @__torch__.add3
|
||||||
# Note that line-level debug information for parts unannotated in the Torch
|
# Note that line-level debug information for parts unannotated in the Torch
|
||||||
# graph are ascribed to the first op that carries source information. Presently
|
# graph are ascribed to the first op that carries source information. Presently
|
||||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: @__torch__.f
|
# CHECK-LABEL: @__torch__.f
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
|
|
|
@ -11,6 +11,7 @@ import typing
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func.func @__torch__.optional_return(
|
# CHECK-LABEL: func.func @__torch__.optional_return(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
|
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
|
||||||
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
|
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
|
||||||
|
|
|
@ -13,6 +13,7 @@ mb = ModuleBuilder()
|
||||||
# else branch and making all defined values optional, so no special handling
|
# else branch and making all defined values optional, so no special handling
|
||||||
# is needed.
|
# is needed.
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: @__torch__.prim_If(
|
# CHECK-LABEL: @__torch__.prim_If(
|
||||||
# CHECK-SAME: %[[B:.*]]: !torch.bool,
|
# CHECK-SAME: %[[B:.*]]: !torch.bool,
|
||||||
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {
|
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import typing
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
|
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
|
||||||
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
|
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
|
||||||
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
|
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
|
||||||
|
|
|
@ -15,6 +15,7 @@ import typing
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
|
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
|
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
|
||||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
|
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
|
||||||
|
|
|
@ -13,6 +13,7 @@ from utils import create_script_function
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
|
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func.func @__torch__.tuple(
|
# CHECK-LABEL: func.func @__torch__.tuple(
|
||||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
|
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
|
||||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK: @__torch__.returns_bool
|
# CHECK: @__torch__.returns_bool
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
|
|
|
@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK: @__torch__.returns_none
|
# CHECK: @__torch__.returns_none
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
|
|
|
@ -9,6 +9,7 @@ from torch._C import CompilationUnit
|
||||||
|
|
||||||
# RUN: %PYTHON %s
|
# RUN: %PYTHON %s
|
||||||
|
|
||||||
|
|
||||||
# Import TorchScript IR string as ScriptFunction.
|
# Import TorchScript IR string as ScriptFunction.
|
||||||
def create_script_function(func_name, ts_ir_str, **kwargs):
|
def create_script_function(func_name, ts_ir_str, **kwargs):
|
||||||
cu = CompilationUnit()
|
cu = CompilationUnit()
|
||||||
|
|
|
@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
|
||||||
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
|
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
|
||||||
|
|
||||||
if _IS_TORCH_2_1_OR_EARLIER:
|
if _IS_TORCH_2_1_OR_EARLIER:
|
||||||
SYMBOLIC_TORCH_OPS = {
|
|
||||||
torch.ops.aten.sym_size,
|
|
||||||
torch.ops.aten.sym_stride,
|
|
||||||
torch.ops.aten.sym_numel,
|
|
||||||
}
|
|
||||||
|
|
||||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||||
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
|
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
|
||||||
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
|
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
|
||||||
|
@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER:
|
||||||
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
||||||
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
SYMBOLIC_TORCH_OPS = {
|
|
||||||
torch.ops.aten.sym_size.int,
|
|
||||||
torch.ops.aten.sym_stride.int,
|
|
||||||
torch.ops.aten.sym_numel.default,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
|
||||||
|
else:
|
||||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||||
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
|
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
|
||||||
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
|
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
|
||||||
|
@ -264,6 +254,8 @@ else:
|
||||||
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
|
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SparsityMeta:
|
class SparsityMeta:
|
||||||
|
@ -1857,8 +1849,7 @@ def _emit_operation(
|
||||||
|
|
||||||
# Opaque value to indicate something is empty. Used in cases where 'None'
|
# Opaque value to indicate something is empty. Used in cases where 'None'
|
||||||
# may have a different meaning.
|
# may have a different meaning.
|
||||||
class EmptyType:
|
class EmptyType: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
Empty = EmptyType()
|
Empty = EmptyType()
|
||||||
|
|
|
@ -156,8 +156,7 @@ class GraphInfo:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class OnnxImportError(Exception):
|
class OnnxImportError(Exception): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class NodeImporter:
|
class NodeImporter:
|
||||||
|
@ -235,22 +234,22 @@ class NodeImporter:
|
||||||
else:
|
else:
|
||||||
default_opset_version = opset_import.version
|
default_opset_version = opset_import.version
|
||||||
if default_opset_version:
|
if default_opset_version:
|
||||||
container_op.attributes[
|
container_op.attributes["torch.onnx_meta.opset_version"] = (
|
||||||
"torch.onnx_meta.opset_version"
|
IntegerAttr.get(i64_type, default_opset_version)
|
||||||
] = IntegerAttr.get(i64_type, default_opset_version)
|
)
|
||||||
if opset_versions:
|
if opset_versions:
|
||||||
container_op.attributes[
|
container_op.attributes["torch.onnx_meta.opset_versions"] = (
|
||||||
"torch.onnx_meta.opset_versions"
|
DictAttr.get(opset_versions)
|
||||||
] = DictAttr.get(opset_versions)
|
)
|
||||||
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
|
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
|
||||||
IntegerType.get_signed(64), m.ir_version
|
IntegerType.get_signed(64), m.ir_version
|
||||||
)
|
)
|
||||||
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
|
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
|
||||||
m.producer_name
|
m.producer_name
|
||||||
)
|
)
|
||||||
container_op.attributes[
|
container_op.attributes["torch.onnx_meta.producer_version"] = (
|
||||||
"torch.onnx_meta.producer_version"
|
StringAttr.get(m.producer_version)
|
||||||
] = StringAttr.get(m.producer_version)
|
)
|
||||||
|
|
||||||
def import_all(self, func=True):
|
def import_all(self, func=True):
|
||||||
"""Imports all nodes topologically."""
|
"""Imports all nodes topologically."""
|
||||||
|
@ -348,8 +347,14 @@ class NodeImporter:
|
||||||
continue
|
continue
|
||||||
elif handler is False:
|
elif handler is False:
|
||||||
# Active error.
|
# Active error.
|
||||||
|
# try matching attribute type ID to name for a more descriptive error message
|
||||||
|
try:
|
||||||
|
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
|
||||||
|
except ValueError:
|
||||||
|
attr_type_name = "UNKNOWN"
|
||||||
raise OnnxImportError(
|
raise OnnxImportError(
|
||||||
f"ONNX importer does not support generic node attribute type {attr_type}. "
|
f"ONNX importer does not support generic node attribute type {attr_type_name} "
|
||||||
|
f"with ID {attr_type}. "
|
||||||
f"This likely means that this is a special node which requires specific "
|
f"This likely means that this is a special node which requires specific "
|
||||||
f"handling in the importer: {onnx_attr}"
|
f"handling in the importer: {onnx_attr}"
|
||||||
)
|
)
|
||||||
|
@ -658,9 +663,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
|
||||||
RankedTensorType.get(shape, IntegerType.get_signed(64)),
|
RankedTensorType.get(shape, IntegerType.get_signed(64)),
|
||||||
IntegerAttr.get(
|
IntegerAttr.get(
|
||||||
IntegerType.get_signed(64),
|
IntegerType.get_signed(64),
|
||||||
|
(
|
||||||
int.from_bytes(tp.raw_data, "little", signed=True)
|
int.from_bytes(tp.raw_data, "little", signed=True)
|
||||||
if tp.HasField("raw_data")
|
if tp.HasField("raw_data")
|
||||||
else tp.int64_data[0],
|
else tp.int64_data[0]
|
||||||
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
|
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
|
||||||
|
@ -703,7 +710,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
|
||||||
),
|
),
|
||||||
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
|
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
|
||||||
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
|
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
|
||||||
)
|
),
|
||||||
# Intentionally unsupported: STRING
|
# Intentionally unsupported: STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
from typing import Optional, Union, Dict, Tuple, Any
|
from typing import Optional, Union, Dict, Tuple, Any, Callable
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def export_and_import(
|
||||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||||
experimental_support_mutation: bool = False,
|
experimental_support_mutation: bool = False,
|
||||||
hooks: Optional[FxImporterHooks] = None,
|
hooks: Optional[FxImporterHooks] = None,
|
||||||
decomposition_table: Optional[list] = None,
|
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
|
||||||
func_name: str = "main",
|
func_name: str = "main",
|
||||||
enable_graph_printing: bool = False,
|
enable_graph_printing: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
0a3e5f5badd8a0cb7fac97f5ec9d48c304e5c0b7
|
34ade3521ca41f20af3469bba276c2b0499c3892
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
--pre
|
--pre
|
||||||
torch==2.4.0.dev20240422
|
torch==2.4.0.dev20240428
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_ifop_basic
|
||||||
|
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
|
||||||
|
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-DAG: } else {
|
||||||
|
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
|
||||||
|
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
|
||||||
|
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||||
|
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
|
||||||
|
}, {
|
||||||
|
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
|
||||||
|
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
||||||
|
return %0 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
|
@ -1996,3 +1996,160 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten
|
||||||
%0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
|
%0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
|
||||||
return %0 : !torch.vtensor<[3,?],f32>
|
return %0 : !torch.vtensor<[3,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_blackmanwindow_symmetric
|
||||||
|
func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
|
||||||
|
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
|
||||||
|
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
|
||||||
|
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
|
||||||
|
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
|
||||||
|
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
|
||||||
|
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
|
||||||
|
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||||
|
return %0 : !torch.vtensor<[10],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_blackmanwindow
|
||||||
|
func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
|
||||||
|
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
|
||||||
|
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
|
||||||
|
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
|
||||||
|
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
|
||||||
|
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
|
||||||
|
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
|
||||||
|
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||||
|
return %0 : !torch.vtensor<[10],f32>
|
||||||
|
}
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_hannwindow
|
||||||
|
func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
|
||||||
|
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
|
||||||
|
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
|
||||||
|
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
|
||||||
|
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
|
||||||
|
|
||||||
|
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||||
|
return %0 : !torch.vtensor<[10],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_hannwindow_symmetric
|
||||||
|
func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
|
||||||
|
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
|
||||||
|
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
|
||||||
|
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
|
||||||
|
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||||
|
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
|
||||||
|
|
||||||
|
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||||
|
return %0 : !torch.vtensor<[10],f32>
|
||||||
|
}
|
||||||
|
|
|
@ -860,6 +860,57 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
|
||||||
|
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
|
||||||
|
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
|
||||||
|
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
|
||||||
|
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
|
||||||
|
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
||||||
|
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
|
||||||
|
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
|
// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
|
||||||
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
@ -942,41 +993,24 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
|
// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example
|
||||||
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
|
||||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||||
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
|
// CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
|
||||||
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
|
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
||||||
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
|
||||||
return %0 : !torch.vtensor<[1,1,1],f32>
|
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
|
// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
|
||||||
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
|
||||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
|
||||||
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
|
||||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
|
||||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
|
||||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
|
||||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
|
||||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
|
|
||||||
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
|
|
||||||
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
|
|
||||||
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
|
|
||||||
return %0 : !torch.vtensor<[3,2,1],f32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
|
|
||||||
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
|
||||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
@ -984,15 +1018,65 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2
|
||||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
||||||
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
|
// CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
|
||||||
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
|
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||||
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
|
||||||
return %0 : !torch.vtensor<[3,2],f32>
|
return %0 : !torch.vtensor<[3,2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero
|
||||||
|
func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8: si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32>
|
||||||
|
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
|
||||||
|
// CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,0,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example
|
||||||
|
func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example
|
||||||
|
func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
|
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
|
||||||
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||||
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
|
|
@ -55,7 +55,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
|
||||||
// CHECK-LABEL: func.func @torch.aten.reciprocal(
|
// CHECK-LABEL: func.func @torch.aten.reciprocal(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
|
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
|
||||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
|
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
|
||||||
|
@ -124,7 +124,7 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
|
||||||
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
||||||
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
|
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
|
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
|
||||||
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
|
||||||
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||||
|
@ -152,7 +152,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
||||||
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
|
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
|
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
|
||||||
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
|
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
|
||||||
|
@ -185,7 +185,7 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
||||||
// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
||||||
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
||||||
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||||
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||||
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
|
||||||
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||||
|
@ -214,7 +214,7 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
|
||||||
// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
|
// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
|
||||||
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
|
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
|
||||||
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
|
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
|
||||||
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
|
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
|
||||||
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
|
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
|
||||||
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
|
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
|
||||||
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
|
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
|
||||||
|
|
|
@ -4,9 +4,9 @@
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[STR:.*]] = torch.constant.str "none"
|
// CHECK: %[[STR:.*]] = torch.constant.str "none"
|
||||||
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
|
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
|
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
|
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
|
||||||
|
@ -487,7 +487,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
|
||||||
// CHECK-LABEL: func.func @torch.aten.relu(
|
// CHECK-LABEL: func.func @torch.aten.relu(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 0.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
|
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
|
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
|
||||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||||
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
|
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
|
||||||
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
|
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
|
||||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||||
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
|
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
|
||||||
|
@ -31,7 +31,7 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
|
||||||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||||
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
|
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
|
||||||
|
@ -53,7 +53,7 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
|
||||||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||||
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
|
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
|
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
|
||||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
|
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
|
||||||
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
|
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
|
||||||
|
|
|
@ -14,11 +14,11 @@
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
||||||
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
|
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
|
||||||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||||
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
||||||
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
|
||||||
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||||
|
@ -46,12 +46,12 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
||||||
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
|
||||||
|
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
|
||||||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||||
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
||||||
// CHECK: })
|
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
|
||||||
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||||
|
@ -96,7 +96,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
||||||
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
|
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||||
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
|
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
|
||||||
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
|
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
|
||||||
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
|
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>}> ({
|
||||||
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
|
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
|
||||||
// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||||
|
@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
||||||
// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||||
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
|
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
|
||||||
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
|
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
|
||||||
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
|
// CHECK: }) : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
|
||||||
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||||
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
|
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
|
||||||
// CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
|
// CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
|
||||||
|
@ -137,11 +137,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
|
||||||
|
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
|
||||||
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
|
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
|
||||||
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
|
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
|
||||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||||
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
|
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
|
||||||
|
@ -158,11 +159,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
||||||
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
|
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
|
||||||
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
|
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
|
||||||
|
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
|
||||||
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
||||||
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
|
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
|
||||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
@ -194,11 +196,12 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
||||||
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
|
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]])
|
||||||
|
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
|
||||||
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||||
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
|
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
|
||||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
|
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
|
||||||
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
||||||
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
|
|
|
@ -22,10 +22,10 @@
|
||||||
// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
|
// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
|
||||||
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
|
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
|
||||||
// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
|
// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
|
||||||
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({
|
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
|
||||||
// CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
|
// CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
|
||||||
// CHECK: stablehlo.return %[[ARG_4]] : tensor<i64>
|
// CHECK: stablehlo.return %[[ARG_4]] : tensor<i64>
|
||||||
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false} : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
// CHECK: }) : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||||
// CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
// CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||||
// CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
|
// CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
|
||||||
func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||||
|
|
|
@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
|
// CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
|
||||||
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
||||||
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
|
// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK-NEXT: return %[[F]] : !torch.bool
|
// CHECK-NEXT: return %[[TRUE]] : !torch.bool
|
||||||
func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
||||||
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
||||||
return %0 : !torch.bool
|
return %0 : !torch.bool
|
||||||
|
@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: return %[[FALSE]] : !torch.bool
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
%str4 = torch.constant.str "4"
|
%str4 = torch.constant.str "4"
|
||||||
%str5 = torch.constant.str "5"
|
%str5 = torch.constant.str "5"
|
||||||
|
@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
|
// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
|
||||||
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
|
||||||
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
|
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
// CHECK-NEXT: return %[[F]] : !torch.bool
|
// CHECK-NEXT: return %[[FALSE]] : !torch.bool
|
||||||
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
|
||||||
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
|
||||||
return %0 : !torch.bool
|
return %0 : !torch.bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
||||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
// CHECK: return %[[TRUE]] : !torch.bool
|
// CHECK: return %[[FALSE]] : !torch.bool
|
||||||
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
|
||||||
%str4 = torch.constant.str "4"
|
%str4 = torch.constant.str "4"
|
||||||
%str4_0 = torch.constant.str "4"
|
%str4_0 = torch.constant.str "4"
|
||||||
|
@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int {
|
||||||
return %2 : !torch.int
|
return %2 : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: return %[[FALSE]] : !torch.bool
|
||||||
|
func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
|
||||||
|
%str = torch.constant.str "c"
|
||||||
|
%str_0 = torch.constant.str "b"
|
||||||
|
%str_1 = torch.constant.str "a"
|
||||||
|
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
|
||||||
|
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
|
||||||
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
|
||||||
|
%str = torch.constant.str "aa"
|
||||||
|
%str_0 = torch.constant.str "aa"
|
||||||
|
%str_1 = torch.constant.str "ccc"
|
||||||
|
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
|
||||||
|
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.__not__
|
// CHECK-LABEL: func.func @torch.aten.__not__
|
||||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: return %[[TRUE]] : !torch.bool
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
@ -2950,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
|
||||||
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
|
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
|
||||||
return %result : !torch.vtensor<[4], f32>
|
return %result : !torch.vtensor<[4], f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
|
||||||
|
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
|
||||||
|
// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
|
||||||
|
func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
|
||||||
|
%int6 = torch.constant.int 6
|
||||||
|
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
|
||||||
|
return %0 : !torch.vtensor<[64],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
|
||||||
|
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
|
||||||
|
// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
|
||||||
|
// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
|
||||||
|
func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
|
||||||
|
%int6 = torch.constant.int 6
|
||||||
|
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
|
||||||
|
return %0 : !torch.vtensor<[64],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
|
||||||
|
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
|
||||||
|
// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
|
||||||
|
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
|
||||||
|
func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
|
||||||
|
return %result0 : !torch.vtensor<[10,64,56,56],f32>
|
||||||
|
}
|
||||||
|
|
|
@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
|
||||||
print(m)
|
print(m)
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
|
||||||
|
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
|
||||||
|
def test_broadcast_with_dynamic_shapes():
|
||||||
|
class Basic(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.broadcast_to(x, (y.shape[0], -1))
|
||||||
|
|
||||||
|
# Sample inputs
|
||||||
|
x = torch.randn(1, 2)
|
||||||
|
y = torch.randn(10)
|
||||||
|
|
||||||
|
dim_0 = Dim("dim_0")
|
||||||
|
dynamic_shapes = {
|
||||||
|
"x": {},
|
||||||
|
"y": {0: dim_0},
|
||||||
|
}
|
||||||
|
|
||||||
|
m = fx.export_and_import(
|
||||||
|
Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
|
||||||
|
)
|
||||||
|
print(m)
|
||||||
|
|
||||||
|
|
||||||
@make_boxed_compiler
|
@make_boxed_compiler
|
||||||
def fx_import_aot_autograd_backend(
|
def fx_import_aot_autograd_backend(
|
||||||
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||||
|
@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(
|
||||||
|
|
||||||
@run
|
@run
|
||||||
# CHECK-LABEL: test_stateless_fx_import
|
# CHECK-LABEL: test_stateless_fx_import
|
||||||
# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||||
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||||
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
|
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
|
||||||
def test_stateless_fx_import():
|
def test_stateless_fx_import():
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
--pre
|
--pre
|
||||||
torchvision==0.19.0.dev20240422
|
torchvision==0.19.0.dev20240428
|
||||||
|
|
Loading…
Reference in New Issue