mirror of https://github.com/llvm/torch-mlir
Merge remote-tracking branch 'upstream/main' into emit_detach_
commit 99e2f143f0
@@ -11,7 +11,7 @@ repos:
       - id: check-yaml
       - id: check-added-large-files
   - repo: https://github.com/psf/black
-    rev: 22.10.0
+    rev: 24.4.2
     hooks:
       - id: black
@@ -0,0 +1,22 @@
+---
+
+extends: default
+
+rules:
+  # These do not appear to be conventional in GitHub actions.
+  document-end:
+    present: false
+  document-start:
+    present: false
+  # GitHub actions use "on" for triggers.
+  truthy: disable
+  # We have lots of long strings and command lines.
+  line-length: disable
+  comments:
+    # Formatters may do this (e.g. Prettier does) and it seems like the most
+    # trivial thing to get a failing check for.
+    min-spaces-from-content: 1
+  # This is not a useful check, especially when disabling entire blocks.
+  comments-indentation: disable
+
+ignore: /third_party/*
@@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
   -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
   -DLLVM_TARGETS_TO_BUILD=host \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+  -DTORCH_MLIR_ENABLE_LTC=ON
 echo "::endgroup::"
 
 echo "::group::Build"
@@ -432,6 +432,8 @@ function clean_build() {
 }
 
 function build_torch_mlir() {
+  # Disable LTC build for releases
+  export TORCH_MLIR_ENABLE_LTC=0
   local torch_version="$1"
   case $torch_version in
     nightly)
@@ -440,7 +442,7 @@ function build_torch_mlir() {
         --extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
       CMAKE_GENERATOR=Ninja \
       TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
-      python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir \
+      python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \
         -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \
         -r /main_checkout/torch-mlir/whl-requirements.txt
       ;;
@@ -450,7 +452,7 @@ function build_torch_mlir() {
       python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
       CMAKE_GENERATOR=Ninja \
       TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
-      python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
+      python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
       ;;
     *)
       echo "Unrecognized torch version '$torch_version'"
@@ -474,7 +476,7 @@ function build_torch_mlir_core() {
   TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
   TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
   TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
-  python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
+  python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
 }
 
 function clean_wheels() {
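Note on the recurring change above: --no-build-isolation tells pip to build the wheel against the packages already installed in the environment (here, the pinned build requirements and the pre-installed torch), instead of resolving a fresh isolated build environment from pyproject.toml. This keeps the built wheel consistent with the exact torch version the scripts install beforehand.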
@@ -2,6 +2,7 @@
 
 See https://github.com/llvm/torch-mlir/issues/1374
 """
 
 import argparse
+import json
 
@@ -1 +1 @@
-Subproject commit a952c123880eb1168f1021b116485e27170d48ca
+Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf
@@ -1 +1 @@
-Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97
+Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91
@@ -30,8 +30,6 @@ namespace detail {
 LogicalResult verifyTMTensorOpInterface(Operation *op);
 }
 
-#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
-
 /// Include the generated interface declarations.
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
@@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op);
 } // namespace torch
 } // namespace mlir
 
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
+
 #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
@@ -97,6 +97,31 @@ struct OpBinder {
     return success();
   }
 
+  ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
+    for (auto result : op->getResults()) {
+      auto t = toValidTensorType(result.getType());
+      if (!t)
+        return failure();
+      typeList.push_back(t);
+    }
+    return success();
+  }
+
+  // The importer imports Onnx.GraphProto attributes as regions attached to the
+  // op.
+  ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
+    if (idx >= op->getNumRegions())
+      return failure();
+
+    region = &op->getRegion(idx);
+
+    if (region == nullptr) {
+      return failure();
+    }
+
+    return success();
+  }
+
   ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
                                       int64_t idx) {
     if (idx >= op->getNumResults())
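For orientation, a minimal usage sketch of the two new helpers; the surrounding pattern body is hypothetical (the ONNX "If" lowering later in this commit uses the same calls):

    // Bind every tensor result type, then grab the op's two attached regions.
    llvm::SmallVector<mlir::Type> resultTypes;
    if (binder.tensorResultTypes(resultTypes))
      return failure(); // some result is not a valid tensor type
    mlir::Region *thenRegion, *elseRegion;
    if (binder.getRegionAtIndex(elseRegion, 0) ||
        binder.getRegionAtIndex(thenRegion, 1))
      return failure(); // fewer than two regions are attached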
@@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder,
 
 Type getQTorchTypeFromTorchIntType(Type ty);
 
+template <typename T>
+Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
+                Value &ofItem) {
+  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
+                                            rewriter.getType<T>(), ofItem);
+}
+
 LogicalResult OnnxLstmExpander(OpBinder binder,
                                ConversionPatternRewriter &rewriter);
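The helper simply wraps a rank-0 tensor value in a torch.aten.item call whose result type is the template parameter. A usage sketch (the `sizeTensor` name is illustrative):

    // Extract the scalar held by a 0-d tensor as a !torch.int; the window
    // function lowering elsewhere in this commit uses this to turn the ONNX
    // `size` operand into a scalar limit for aten.arange.
    Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, sizeTensor);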
@@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor, ArrayRef<int64_t> inputUnsqzDims,
                                  size_t dimSizeIndexBits);
 
+// Get a tensor that collapses the specified dimensions of the input tensor
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits);
+
+// Get a tensor that splits the specified dimension of the input tensor
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits);
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType);
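A shape-level illustration of the two declarations (call sites and the 64-bit index width are assumed for the example): collapsing dims 1..2 of a [2, 3, 4] tensor yields [2, 12], and splitting dim 1 of that result with outerLength = 3 reverses it.

    // Hypothetical call sites inside a pattern; rewriter/op/input are given.
    // [2, 3, 4] -> [2, 12]
    FailureOr<Value> collapsed = collapseTensor(
        rewriter, op, input, /*collapseStartDim=*/1, /*collapseEndDim=*/2,
        /*dimSizeIndexBits=*/64);
    // [2, 12] -> [2, 3, 4]; outerLength is the size of the new outer dim.
    FailureOr<Value> split = splitTensor(
        rewriter, op, *collapsed, /*splitDim=*/1, /*outerLength=*/3,
        /*dimSizeIndexBits=*/64);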
@@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
   }];
 }
 
+def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenCeluOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenCelu_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenRealOp : Torch_Op<"aten.real", [
     AllowsTypeRefinement,
     ReadOnly
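For reference, the activation these two ops model, per the standard aten::celu definition (the aten.celu_ variant is the usual in-place form over non-value tensors):

    CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))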
@@ -6590,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
   }];
 }
 
+def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$kernel_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$ceil_mode
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -6645,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
       printDefaultTorchOp(printer, *this, 6, 2);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
@@ -13601,6 +13677,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
   }];
 }
 
+def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
+  let arguments = (ins
+    AnyTorchListOfTorchStringType:$l,
+    Torch_StringType:$item
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -15904,6 +16005,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_PrimsVarOp : Torch_Op<"prims.var", [
@@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
   return detail::torch_list_of_constant_bools_op_binder(bind_values);
 }
 
+namespace detail {
+/// Matches the constant strs stored in a `torch.ListConstruct`.
+struct torch_list_of_constant_strs_op_binder {
+  SmallVectorImpl<std::string> &bind_values;
+
+  /// Creates a matcher instance that binds the value to bvs if match succeeds.
+  torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
+      : bind_values(bvs) {}
+
+  bool match(Operation *op) {
+    auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
+    if (!listConstruct)
+      return false;
+    for (Value value : listConstruct.getElements()) {
+      std::string str;
+      if (matchPattern(value, m_TorchConstantStr(str)))
+        bind_values.push_back(str);
+      else
+        return false;
+    }
+    return true;
+  }
+};
+} // namespace detail
+
+/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
+inline detail::torch_list_of_constant_strs_op_binder
+m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
+  return detail::torch_list_of_constant_strs_op_binder(bind_values);
+}
+
 namespace detail {
 /// Matches the expected tensor and dim from `torch.aten.size.int`.
 struct torch_tensor_size_int_op_binder {
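A short consumption sketch (`listValue` is a placeholder for the result of a torch.prim.ListConstruct): matchPattern succeeds only if every element is a constant string, in which case the vector holds the elements in order — the basis for the new Aten__Contains__StrListOp folder.

    SmallVector<std::string> strs;
    if (matchPattern(listValue, m_TorchListOfConstantStrs(strs))) {
      // strs now contains the constant list elements, in order.
    }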
@@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+namespace {
+LogicalResult windowFunctionImpl(OpBinder binder,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value size, Value a0, Value a1, Value a2,
+                                 Torch::ValueTensorType resultType,
+                                 int64_t output_datatype, int64_t periodic) {
+
+  Location loc = binder.getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+
+  double isPeriodicFp = static_cast<double>(periodic);
+
+  Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
+  Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
+  Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
+
+  constexpr double pi = llvm::numbers::pi;
+  Value tau = b.create<Torch::ConstantFloatOp>(
+      rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+
+  Value noneVal = b.create<Torch::ConstantNoneOp>();
+  Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+  Value float32Type = b.create<Torch::ConstantIntOp>(
+      rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+
+  // Create an f32 ValueTensorType with the same size as size, the
+  // operand
+  auto shapeOfOperand =
+      size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
+  auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
+      shapeOfOperand, rewriter.getF32Type());
+  Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
+      f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
+  Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, one, one);
+
+  Value isPeriodic =
+      b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
+  Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
+      rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
+
+  Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
+  Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
+      symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
+  Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
+      symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
+
+  // Here, size can be used in the place of periodicSizeFloat, as the
+  // latter is just a float representation of the former.
+  Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
+
+  Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
+      resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
+
+  Value rangeTimesTau =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
+  Value rangeAngular =
+      b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
+  Value twoRangeAngular =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
+
+  Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
+  Value cosTwoRangeAngular =
+      b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
+
+  Value a1Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
+  Value a2Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
+
+  // AtenSubScalarOp actually requires a tensor operand as the LHS, that
+  // is, operand #1. Therefore, to avoid errors, the onnx implementation
+  // has been modified. a1 has been changed to negative half, and the
+  // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
+  // operation is commutative.
+  Value subA1Component =
+      b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
+  Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
+                                                  a2Component, one);
+
+  std::optional<int64_t> dtypeIntTorch =
+      onnxDtypeIntToTorchDtypeInt(output_datatype);
+  if (!dtypeIntTorch.has_value()) {
+    return rewriter.notifyMatchFailure(
+        binder.op, "unimplemented support for the given dtype conversion");
+  }
+  Value outputDtype = b.create<Torch::ConstantIntOp>(
+      rewriter.getType<Torch::IntType>(),
+      rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                              dtypeIntTorch.value()));
+
+  rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+      binder.op, resultType, result, outputDtype,
+      /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+      /*memory_format=*/noneVal);
+
+  return success();
+}
+
+} // namespace
+
 // Simple rewrites for the default domain.
 // See: https://onnx.ai/onnx/operators/
 // For operators that are effectively version invariant, we register with
@@ -2186,4 +2288,65 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
     return success();
   });
+  patterns.onOp(
+      "BlackmanWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
+
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
+
+        return success();
+      });
+
+  patterns.onOp(
+      "HannWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
+
+        return success();
+      });
 }
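Both patterns instantiate the generalized cosine window that windowFunctionImpl evaluates elementwise for n = 0 .. size-1, with N = size in the periodic case and N = size - 1 in the symmetric case:

    w[n] = a0 + a1 * cos(2*pi*n / N) + a2 * cos(4*pi*n / N)

BlackmanWindow passes (a0, a1, a2) = (0.42, -0.5, 0.08) and HannWindow passes (0.5, -0.5, 0.0); the a1 coefficients are negated so that the add-based formulation above (see the AtenAddScalarOp comment in the helper) matches the textbook subtraction form.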
@@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                       alignCorners);
                   return success();
                 });
+  patterns.onOp(
+      "If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value conditionTensor;
+        if (binder.tensorOperand(conditionTensor)) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "condition bind failure");
+        }
+
+        auto conditionType =
+            conditionTensor.getType().cast<Torch::ValueTensorType>();
+        if (!conditionType || conditionType.getSizes().size() != 1)
+          return rewriter.notifyMatchFailure(
+              binder.op, "condition must have one single element per "
+                         "https://onnx.ai/onnx/operators/onnx__If.html");
+        auto conditionInt = rewriter.create<Torch::AtenItemOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            conditionTensor);
+        auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
+
+        llvm::SmallVector<mlir::Type> resultTypes;
+        if (binder.tensorResultTypes(resultTypes)) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "result type bind failure");
+        }
+
+        Region *thenRegion, *elseRegion;
+        if (binder.getRegionAtIndex(elseRegion, 0) ||
+            binder.getRegionAtIndex(thenRegion, 1)) {
+          return rewriter.notifyMatchFailure(binder.op, "region bind failure");
+        }
+
+        auto primIfOp = rewriter.create<Torch::PrimIfOp>(
+            binder.getLoc(), TypeRange(resultTypes), conditionBool);
+
+        auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
+          rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
+        };
+        inlineIfCase(*thenRegion, primIfOp.getThenRegion());
+        inlineIfCase(*elseRegion, primIfOp.getElseRegion());
+
+        auto replaceTerminator = [&](Region &region) {
+          PatternRewriter::InsertionGuard guard(rewriter);
+          Operation *terminator = region.front().getTerminator();
+          rewriter.setInsertionPoint(terminator);
+          rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
+              terminator, terminator->getOperands());
+        };
+        replaceTerminator(primIfOp.getThenRegion());
+        replaceTerminator(primIfOp.getElseRegion());
+
+        rewriter.replaceOp(binder.op, primIfOp.getResults());
+        return success();
+      });
   patterns.onOp("Less", 13,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c;
 // thing here, so we simplify.
 
 // utilities
-// Templatized function to get an item op of a type
 namespace {
-template <typename T>
-Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
-                Value &ofItem) {
-  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
-                                            rewriter.getType<T>(), ofItem);
-}
-
 // In case the ReduceSum Op was not the first operation performed on the data,
 // we provide the original operand through storeResult, which will be modified
 // if the result will be passed onto another operation, and will be used for
@@ -847,12 +839,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
   patterns.onOp(
       "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
         Torch::ValueTensorType resultType;
         float alpha, gamma;
         Value operand;
+        // Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default
+        // alpha and gamma values.
         if (binder.tensorOperand(operand) ||
-            binder.f32FloatAttr(alpha, "alpha") ||
-            binder.f32FloatAttr(gamma, "gamma") ||
+            binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
+            binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
             binder.tensorResultType(resultType))
           return failure();
@@ -945,22 +940,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*memory_format=*/noneVal);
         return success();
       });
-  patterns.onOp("ReduceSum", 1,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value data;
-                  int64_t keepDims, noop_with_empty_axes;
-                  if (binder.tensorOperandAtIndex(data, 0) ||
-                      binder.tensorResultType(resultType) ||
-                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
-                      binder.s64IntegerAttr(noop_with_empty_axes,
-                                            "noop_with_empty_axes", 0))
-                    return failure();
-
-                  return reducedSumImpl(binder, rewriter, data, resultType,
-                                        /*storeValue=*/data, keepDims,
-                                        noop_with_empty_axes, false);
-                });
   patterns.onOp("ReduceLogSum", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -987,6 +966,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                       binder.op, resultType, data);
                   return success();
                 });
+  patterns.onOp("ReduceSum", 1,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value data;
+                  int64_t keepDims, noop_with_empty_axes;
+                  if (binder.tensorOperandAtIndex(data, 0) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+                      binder.s64IntegerAttr(noop_with_empty_axes,
+                                            "noop_with_empty_axes", 0))
+                    return failure();
+
+                  return reducedSumImpl(binder, rewriter, data, resultType,
+                                        /*storeValue=*/data, keepDims,
+                                        noop_with_empty_axes, false);
+                });
+  patterns.onOp("ReduceSumSquare", 1,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value data;
+                  int64_t keepDims, noop_with_empty_axes;
+                  if (binder.tensorOperandAtIndex(data, 0) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+                      binder.s64IntegerAttr(noop_with_empty_axes,
+                                            "noop_with_empty_axes", 0))
+                    return failure();
+
+                  Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
+                      binder.getLoc(), data.getType(), data, data);
+
+                  return reducedSumImpl(binder, rewriter, dataSquare,
+                                        resultType,
+                                        /*storeValue=*/data, keepDims,
+                                        noop_with_empty_axes, false);
+                });
   patterns.onOp(
       "ReduceMean", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
   if (!isUnsignedType)
     return;
   int64_t minSI = -(1 << (numBits - 1));
-  Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
+  Value minSIValue = rewriter.create<arith::ConstantIntOp>(
+      loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
   zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
+  minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
   arg = torch_to_linalg::createElementwiseLinalgGeneric(
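For intuition: signShift reinterprets an n-bit unsigned quantized tensor as signed by adding minSI = -2^(n-1) to both the stored values and the zero point, which leaves the dequantized result unchanged since

    (x + minSI) - (zp + minSI) = x - zp

e.g. for n = 8, minSI = -128 maps [0, 255] onto [-128, 127]. The fix above sizes the zero-point constant to zp's actual integer width rather than a hard-coded 32 bits, and emits a separate numBits-wide constant for shifting the tensor elements.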
@@ -797,6 +798,8 @@ public:
     auto resultTy = cast<ValueTensorType>(op.getType());
 
     Value inputZp, weightZp;
+    bool inputUnsigned = false;
+    bool weightUnsigned = false;
     if (auto make = op.getInput()
                         .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
       input = make.getSelf();
@@ -806,6 +809,8 @@ public:
       inputZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(inputZp.getType()),
           inputZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
 
     if (auto make = op.getWeight()
@@ -818,6 +823,8 @@ public:
       weightZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(weightZp.getType()),
           weightZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
 
     if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
@@ -916,15 +923,35 @@ public:
     SmallVector<Value> strideIntValues =
         getAsConstantIntValues(rewriter, loc, strideInts);
 
+    // convert any uint8 quantization to int8 quantization
+    if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
+      int64_t width = integerType.getWidth();
+      signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
+    }
+    if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
+      int64_t width = integerType.getWidth();
+      signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
+    }
     // Pad the input tensor according to padding.
     SmallVector<Value> outDims{inBatch, weightBatch};
     Value paddedInput;
-    if (transposed) {
-      if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
-          !isa<mlir::FloatType>(resultDTy))
-        return rewriter.notifyMatchFailure(
-            op, "transpose does not support non-fp type yet");
+    Value pad = inputZp;
+    if (!pad) {
+      if (isa<mlir::FloatType>(inputDTy))
+        pad = rewriter.create<arith::ConstantOp>(
+            op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
+      if (isa<mlir::IntegerType>(inputDTy))
+        pad = rewriter.create<arith::ConstantOp>(
+            op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
+    }
+    if (pad.getType() != inputDTy) {
+      if (isa<mlir::FloatType>(inputDTy))
+        pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
+
+      if (isa<mlir::IntegerType>(inputDTy))
+        pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
+    }
+    if (transposed) {
       Value c0 =
           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
       Value c1 =
@@ -994,7 +1021,7 @@ public:
 
       // Allocate padded input tensor
       Value initTensor =
-          createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
+          createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
 
       // Insert input into allocated tensor
       SmallVector<Value> strideIndexValues{c1, c1};
@@ -1017,24 +1044,6 @@ public:
       strideInts.clear();
       strideInts.append(numSpatialDims, 1);
     } else {
-      Value pad = inputZp;
-      if (!pad) {
-        if (isa<mlir::FloatType>(inputDTy))
-          pad = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
-        if (isa<mlir::IntegerType>(inputDTy))
-          pad = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
-      }
-
-      if (pad.getType() != inputDTy) {
-        if (isa<mlir::FloatType>(inputDTy))
-          pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
-
-        if (isa<mlir::IntegerType>(inputDTy))
-          pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
-      }
-
       // Pad input
       paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
           op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
@@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   auto constType = RankedTensorType::get({}, elementTy);
   // Avg pooling
   if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
-          AtenCumsumOp>(op)) {
+          AtenAvgPool3dOp, AtenCumsumOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType, {APFloat::getZero(
@@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   }
 
   // Max pooling
-  if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
+  if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
+          AtenMaxPool2dWithIndicesOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
-// AtenMaxPool2dOp
-template <>
-LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
-    AtenMaxPool2dOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = cast<RankedTensorType>(input.getType());
-  auto inputElemTy = inputTy.getElementType();
-
-  auto inputRank = inputTy.getRank();
-  auto outTy =
-      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-
-  if (inputRank <= 2) {
-    return op.emitError(
-        "max_pooling2d only supports inputs with rank higher than 2");
-  }
-  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
-  bool ceilMode = false;
-
-  if (!(matchPattern(op.getKernelSize(),
-                     m_TorchListOfConstantInts(kernelSize)))) {
-    return rewriter.notifyMatchFailure(
-        op, "non-const int kernel size unsupported!");
-  }
-  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
-    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
-  }
-  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const int padding unsupported!");
-  }
-  if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const int dilation unsupported!");
-  }
-  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const bool ceil_mode unsupported!");
-  }
-
-  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
-  // input
-  SmallVector<int64_t> stablehloStride(inputRank, 1);
-  SmallVector<int64_t> stablehloDilation(inputRank, 1);
-  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
-  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
-  std::copy(dilation.begin(), dilation.end(),
-            stablehloDilation.begin() + inputRank - 2);
-  std::copy(stride.begin(), stride.end(),
-            stablehloStride.begin() + inputRank - 2);
-  std::copy(kernelSize.begin(), kernelSize.end(),
-            stablehloKernelSize.begin() + inputRank - 2);
-
-  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
-
-  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
-  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
-  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
-  stablehloPadding[stablehloPadding.size() - 1] = padding[1];
-
-  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
-  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
-  DenseI64ArrayAttr baseDilations;
-  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
-  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
-      RankedTensorType::get(
-          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
-          rewriter.getI64Type()),
-      stablehloPadding);
-  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
-      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
-      baseDilations, windowDilations, pad);
-
-  Block &block = reduceWindowOp.getBody().emplaceBlock();
-
-  auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArg = block.args_begin();
-  auto secondArg = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value result =
-        rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
-  }
-
-  rewriter.replaceOp(op, reduceWindowOp.getResults());
-  return success();
-}
-
 // AtenMaxPool2dWithIndicesOp
 template <>
 LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
   return success();
 }
 
+namespace {
+template <typename AtenOpT, int Dim>
+class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
+public:
+  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = cast<RankedTensorType>(input.getType());
+    auto inputElemTy = inputTy.getElementType();
+    auto inputRank = inputTy.getRank();
+    auto outTy = cast<RankedTensorType>(
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
+
+    if (inputRank <= Dim) {
+      return op.emitError(
+          "max_pooling1d/2d only supports inputs with rank higher than 1/2");
+    }
+    SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
+    bool ceilMode = false;
+
+    if (!(matchPattern(op.getKernelSize(),
+                       m_TorchListOfConstantInts(kernelSize)))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const int kernel size unsupported!");
+    }
+    if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int stride unsupported!");
+    }
+    if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int padding unsupported!");
+    }
+    if (!(matchPattern(op.getDilation(),
+                       m_TorchListOfConstantInts(dilation)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int dilation unsupported!");
+    }
+    if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const bool ceil_mode unsupported!");
+    }
+
+    if (stride.empty()) {
+      stride = kernelSize;
+    }
+
+    // prepend 1 to kernelSize, stride, dilation until they are of same rank
+    // as input
+    SmallVector<int64_t> stablehloStride(inputRank, 1);
+    SmallVector<int64_t> stablehloDilation(inputRank, 1);
+    SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+    SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
+    std::copy(dilation.begin(), dilation.end(),
+              stablehloDilation.begin() + inputRank - Dim);
+    std::copy(stride.begin(), stride.end(),
+              stablehloStride.begin() + inputRank - Dim);
+    std::copy(kernelSize.begin(), kernelSize.end(),
+              stablehloKernelSize.begin() + inputRank - Dim);
+
+    Value initVal =
+        createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+    if (Dim == 1) {
+      stablehloPadding[stablehloPadding.size() - 2] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
+    } else if (Dim == 2) {
+      stablehloPadding[stablehloPadding.size() - 4] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[1];
+    } else if (Dim == 3) {
+      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
+    } else {
+      assert(false && "Unsupported pooling dimension");
+    }
+    auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+    auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+    DenseI64ArrayAttr baseDilations;
+    auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
+
+    DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+        RankedTensorType::get(
+            {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+            rewriter.getI64Type()),
+        stablehloPadding);
+
+    auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
+        op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
+        baseDilations, windowDilations, pad);
+
+    Block &block = reduceWindowOp.getBody().emplaceBlock();
+
+    // Add bb argument
+    auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
+    block.addArgument(blockArgumentType, op->getLoc());
+    block.addArgument(blockArgumentType, op->getLoc());
+    auto *firstArg = block.args_begin();
+    auto secondArg = block.args_rbegin();
+
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&block);
+
+      Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
+                                                       *secondArg);
+      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+    }
+
+    rewriter.replaceOp(op, reduceWindowOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenOpT, int Dim>
 class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
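A note on the rank handling in the template above: the op's kernel/stride/dilation lists cover only the trailing Dim spatial dimensions, so they are right-aligned into rank-length arrays pre-filled with 1s. A worked example with assumed shapes, max_pool1d over a rank-3 NCW input:

    // inputRank = 3, Dim = 1, kernelSize = {5}
    // stablehloKernelSize starts as {1, 1, 1}; std::copy right-aligns it to
    // {1, 1, 5}, so stablehlo.reduce_window reduces only over the trailing
    // W dimension.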
@@ -375,8 +404,8 @@ public:
     auto outShape = outTy.getShape();
 
     if (inputRank <= Dim) {
-      return op.emitError(
-          "avg_pooling1d/2d only supports inputs with rank higher than 1/2");
+      return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
+                          "higher than 1/2/3");
     }
     SmallVector<int64_t, Dim> padding, kernelSize, stride;
     bool ceilMode = false;
@@ -405,6 +434,10 @@ public:
           op, "non-const bool count_include_pad unsupported!");
     }
 
+    if (stride.empty()) {
+      stride = kernelSize;
+    }
+
     if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
       if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
         return rewriter.notifyMatchFailure(
@@ -425,11 +458,20 @@ public:
     if (Dim == 1) {
       stablehloPadding[stablehloPadding.size() - 2] = padding[0];
      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
-    } else {
+    } else if (Dim == 2) {
       stablehloPadding[stablehloPadding.size() - 4] = padding[0];
       stablehloPadding[stablehloPadding.size() - 3] = padding[0];
       stablehloPadding[stablehloPadding.size() - 2] = padding[1];
       stablehloPadding[stablehloPadding.size() - 1] = padding[1];
+    } else if (Dim == 3) {
+      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
+    } else {
+      assert(false && "Unsupported pooling dimension");
+    }
 
     Value initVal =
@@ -474,10 +516,17 @@ public:
       divisor =
           hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
               .value();
-    } else {
+    } else if (Dim == 2) {
       divisor = hlo::getConstTensor<int64_t>(
                     rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                     .value();
+    } else if (Dim == 3) {
+      divisor = hlo::getConstTensor<int64_t>(
+                    rewriter, op,
+                    {kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
+                    .value();
+    } else {
+      assert(false && "Unsupported pooling dimension");
+    }
     divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
     DenseI64ArrayAttr bcastDimensions;
@@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options) {
   MLIRContext *context = patterns.getContext();
-  target.addIllegalOp<AtenAvgPool1dOp>();
-  patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
-  target.addIllegalOp<AtenMaxPool2dOp>();
-  patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
-  target.addIllegalOp<AtenAvgPool2dOp>();
-  patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
-  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
-  patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
-                                                          context, options);
-  target.addIllegalOp<AtenCumsumOp>();
-  patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
+#define INSERT_ATEN_POOLING_PATTERN(AtenOp)                                    \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
+  INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
+#undef INSERT_ATEN_POOLING_PATTERN
+
+#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim)                               \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context,      \
+                                                  options)
+  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
+  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
+  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
+#undef INSERT_ATEN_MAXPOOL_PATTERN
+
+#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim)                               \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context,      \
+                                                  options)
+  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
+  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
+  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
+#undef INSERT_ATEN_AVGPOOL_PATTERN
 }
@@ -9,6 +9,7 @@
 
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits) {
+
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+
+  collapseStartDim = toPositiveDim(collapseStartDim, rank);
+  collapseEndDim = toPositiveDim(collapseEndDim, rank);
+
+  int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  Value collapseDimSize = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  int64_t collapseShape = 1;
+
+  for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
+    if (k < 0 || k >= rank) {
+      return rewriter.notifyMatchFailure(
+          op, "collapse dimensions must be within the rank of the tensor");
+    }
+    if (collapseShape == ShapedType::kDynamic ||
+        oldShape[k] == ShapedType::kDynamic) {
+      collapseShape = ShapedType::kDynamic;
+    } else {
+      collapseShape *= oldShape[k];
+    }
+    collapseDimSize =
+        rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
+  }
+
+  for (int64_t k = 0; k < collapseStartDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(collapseDimSize);
+  newShape.push_back(collapseShape);
+  for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
+// TODO: support splitDim & outerLength to be Value
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits) {
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+  splitDim = toPositiveDim(splitDim, rank);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  if (splitDim < 0 || splitDim >= rank) {
+    return rewriter.notifyMatchFailure(
+        op, "split dimensions must be within the rank of the tensor");
+  }
+
+  int64_t newRank = rank + 1;
+  auto outerLengthValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, outerLength));
+
+  auto innerLengthValue = rewriter.create<arith::DivSIOp>(
+      loc, dimSizes[splitDim], outerLengthValue);
+
+  int64_t originShape = oldShape[splitDim];
+  int64_t outerShape = outerLength;
+  int64_t innerShape = originShape == ShapedType::kDynamic
+                           ? ShapedType::kDynamic
+                           : originShape / outerLength;
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  for (int64_t k = 0; k < splitDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(outerLengthValue);
+  newShape.push_back(outerShape);
+  newDimSizes.push_back(innerLengthValue);
+  newShape.push_back(innerShape);
+
+  for (int64_t k = splitDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType) {
@@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant end is currently supported");
 
-  start = toPositiveDim(start, rank);
-  end = toPositiveDim(end, rank);
-  SmallVector<int64_t, 4> dims;
-  dims.reserve(rank);
-  for (int r = 0; r < start; ++r)
-    dims.push_back(r);
-  int64_t collapsedDimSize = 1;
-  for (int r = start; r <= end; ++r) {
-    if (selfType.getShape()[r] == ShapedType::kDynamic)
-      return rewriter.notifyMatchFailure(
-          op, "the size of the dimension being collapsed can't be unknown");
-    collapsedDimSize *= selfType.getShape()[r];
-  }
-  dims.push_back(collapsedDimSize);
-  for (int r = end + 1; r < rank; ++r)
-    dims.push_back(r);
-
-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(
-      rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
-  if (failed(newDimSizesInfo))
-    return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  auto newDimSizes = *newDimSizesInfo;
-  auto stablehloShape =
-      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
-  rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-      stablehloShape);
+  auto collapseTensorInfo = hlo::collapseTensor(
+      rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+  if (failed(collapseTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
+
+  rewriter.replaceOp(op, *collapseTensorInfo);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
+    PrimsSplitDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("only tensor types are currently supported");
+  }
+
+  auto rank = selfType.getRank();
+  if (rank == 0)
+    return rewriter.notifyMatchFailure(
+        op, "the rank of tensor must be greater than 0");
+
+  int64_t dim, outerLength;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+  if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant outerLength is currently supported");
+
+  auto splitTensorInfo = hlo::splitTensor(
+      rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+
+  if (failed(splitTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create split tensor");
+
+  rewriter.replaceOp(op, *splitTensorInfo);
   return success();
 }
 
@@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
   INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
   INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+  INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
|
||||
}
|
||||
|
||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
|
||||
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
|
||||
auto dtypeElemType = dtypeComplex.getElementType();
|
||||
|
||||
// Extract the real and imaginary parts of the scalar.
|
||||
// Cast them to the target element type, and create a new complex
|
||||
// value with the target complex type.
|
||||
Value realVal = b.create<complex::ReOp>(loc, scalar);
|
||||
Value imgVal = b.create<complex::ImOp>(loc, scalar);
|
||||
|
||||
realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
|
||||
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
|
||||
|
||||
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||
}
|
||||
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype
|
||||
<< "(dtype)";
|
||||
}
|
||||
|
||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
}
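At the value level, the new complex-to-complex branch behaves like the following PyTorch sketch: split the scalar into real and imaginary parts, cast each to the target element type, and reassemble. This illustrates the semantics only, not the MLIR builder API:

    import torch

    def convert_complex(scalar: torch.Tensor, target: torch.dtype) -> torch.Tensor:
        # Mirrors complex::ReOp / complex::ImOp, the recursive element casts,
        # and complex::CreateOp in the C++ above.
        elem = torch.float32 if target == torch.complex64 else torch.float64
        return torch.complex(scalar.real.to(elem), scalar.imag.to(elem))

    z = torch.tensor(1.5 + 2.5j, dtype=torch.complex128)
    assert convert_complex(z, torch.complex64).dtype == torch.complex64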
@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
    // If no operand comes from a tensor::CastOp and can be folded then fail.
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (opOperand->get().isa<BlockArgument>())
          if (isa<BlockArgument>(opOperand->get()))
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
    return nullptr;

  Type inputDtype = inputTensorType.getOptionalDtype();
  if (!inputDtype || !inputDtype.isInteger(64))
  if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
    return nullptr;

  std::optional<unsigned> inputRank = getTensorRank(input);

@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
    return nullptr;

  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
                   .getSplatValue<int64_t>();
    return rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(val));
    if (inputDtype.isInteger(64)) {
      auto val = valueTensorLiteralOp.getValue()
                     .cast<DenseIntElementsAttr>()
                     .getSplatValue<int64_t>();
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(val));
    } else {
      auto val = valueTensorLiteralOp.getValue()
                     .cast<DenseIntElementsAttr>()
                     .getSplatValue<bool>();
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(val));
    }
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    return primNumToTensorScalarOp.getA();

@ -2385,6 +2394,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten__Contains__StrListOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
  StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
  if (!item)
    return nullptr;

  if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
    if (isListPotentiallyMutated(listConstruct))
      return nullptr;
  }
  llvm::SmallVector<std::string> strs;
  if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
    for (const auto &str : strs) {
      if (item.getValue().str() == str)
        return getI1IntegerAttr(getContext(), true);
    }
    return getI1IntegerAttr(getContext(), false);
  }
  return nullptr;
}
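The folder simply reproduces Python's membership test over a list of constant strings, bailing out whenever the list might be mutated after construction. In plain Python terms:

    # What the fold computes when the operand is a non-mutated
    # prim.ListConstruct of constant strings.
    assert ("b" in ["a", "b"]) is True
    assert ("c" in ["a", "b"]) is False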
//===----------------------------------------------------------------------===//
// AtenLtIntOp
//===----------------------------------------------------------------------===//

@ -4682,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// PrimsConvertElementTypeOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
  auto inputType = cast<BaseTensorType>(getA().getType());
  auto outputType = cast<BaseTensorType>(getResult().getType());
  if (inputType != outputType)
    return nullptr;
  if (!inputType.hasDtype() || !outputType.hasDtype())
    return nullptr;
  if (inputType.getDtype() != outputType.getDtype())
    return nullptr;
  return getA();
}

//===----------------------------------------------------------------------===//
// AtenMaxPool2dWithIndicesOp
//===----------------------------------------------------------------------===//

void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
    if (!op.getResult1().use_empty()) {
      return rewriter.notifyMatchFailure(
          op, "result1 of MaxPool2dWithIndices should be unused");
    }

    Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
        op->getLoc(), op.getResult0().getType(), op.getSelf(),
        op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
        op.getCeilMode());

    op.getResult0().replaceAllUsesWith(result);
    rewriter.eraseOp(op);
    return success();
  });
}
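This canonicalization is sound because the pooled values are identical whether or not indices are materialized; only a dead second result is dropped. A quick PyTorch check of that equivalence (illustrative only):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    out, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
    # When `idx` is unused, the rewrite above replaces the op with plain
    # max_pool2d, which produces exactly the same values.
    assert torch.equal(out, F.max_pool2d(x, kernel_size=2, stride=2))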
//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//

@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"

@ -7841,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    return %arg2 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
"    %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
"    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
"  func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
"    %int-1 = torch.constant.int -1\n"
"    %int-2 = torch.constant.int -2\n"
"    %int-3 = torch.constant.int -3\n"
"    %str = torch.constant.str \"AssertionError: \"\n"
"    %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
"    %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
"    %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
"    %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
"    %true = torch.constant.bool true\n"
"    %none = torch.constant.none\n"
"    %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
"    %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
"    %int1 = torch.constant.int 1\n"
"    %int0 = torch.constant.int 0\n"
"    %int2 = torch.constant.int 2\n"

@ -7936,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    }\n"
"    return %23 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
"    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
"    %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"

@ -10480,6 +10488,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    }\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
"    %none = torch.constant.none\n"
"    %str = torch.constant.str \"AssertionError: \"\n"
@ -1059,44 +1059,44 @@ public:
  LogicalResult matchAndRewrite(AtenEyeMOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    int64_t n;

    if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
      return rewriter.notifyMatchFailure(op,
                                         "unimplemented: n must be constant");
    int64_t m;
    if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
      return rewriter.notifyMatchFailure(op,
                                         "unimplemented: m must be constant");
    Value none = rewriter.create<ConstantNoneOp>(loc);
    auto outType = dyn_cast<BaseTensorType>(op.getType());
    auto outType = op.getType().dyn_cast<BaseTensorType>();
    if (!outType)
      return rewriter.notifyMatchFailure(
          op, "Only tensor types input are currently supported");
    if (!outType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }
    if (n < 0) {
      return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
    }
    if (m < 0) {
      return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
    }

    Value none = rewriter.create<ConstantNoneOp>(loc);
    auto context = op.getContext();
    auto int64Dtype = getDtypeIntValueForType(
        rewriter, loc,
        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
    auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);

    int64_t n = kUnknownSize;
    int64_t m = kUnknownSize;
    // prioritize getting shape from output shape
    if (outType.hasSizes() && outType.getSizes().size() == 2) {
      n = outType.getSizes().front();
      m = outType.getSizes().back();
    }
    // if output shape is not available, try to get shape from input
    if (n == kUnknownSize)
      matchPattern(op.getN(), m_TorchConstantInt(&n));
    if (m == kUnknownSize)
      matchPattern(op.getM(), m_TorchConstantInt(&m));

    // prepare two unsqueezed ranges that are equal on and only on the diagonal
    auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
    Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
    Value rangeN = rewriter.create<AtenArangeOp>(
        loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
        loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
        /*device=*/op.getDevice(), /*pin_memory=*/none);

    auto arangeType1 =
        outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
    auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
    Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
    Value rangeM = rewriter.create<AtenArangeOp>(
        loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
        loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);

    Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(

@ -1109,7 +1109,6 @@ public:
    }
    Value unsqzRangeN = *unsqzTensorInfo;

    // compare unsqueezed input with boundaries
    auto eqType = ValueTensorType::get(
        context, cast<BaseTensorType>(op.getType()).getSizes(),
        IntegerType::get(context, 1));
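The decomposition materializes `eye(n, m)` by comparing a column `arange(n)` against a row `arange(m)`: the comparison is true exactly on the diagonal. The same trick in plain PyTorch (a sketch of the idea, not the emitted IR):

    import torch

    n, m = 3, 5
    rows = torch.arange(n).unsqueeze(-1)  # shape (n, 1)
    cols = torch.arange(m).unsqueeze(0)   # shape (1, m)
    # Broadcasted equality is true only where row index == column index.
    eye = (rows == cols).to(torch.float32)
    assert torch.equal(eye, torch.eye(n, m))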
@ -2415,6 +2414,50 @@ public:

} // namespace

// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenCeluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value alpha = op.getAlpha();
    auto resType = cast<BaseTensorType>(op.getType());
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    Value constantZero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value constantOne =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));

    // positiveOutput = max(0, x)
    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
    Value positiveOutput =
        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);

    // negativeOutput = min(0, alpha * (exp(x / alpha) - 1))
    Value scaledInput =
        rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
    Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
                                                    constantOne, constantOne);
    Value scaledExpXM1 =
        rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
    Value negativeOutput =
        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
    Value celuOutput = rewriter.create<AtenAddTensorOp>(
        loc, resType, positiveOutput, negativeOutput, constantOne);

    rewriter.replaceOp(op, celuOutput);
    return success();
  }
};
} // namespace
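The pattern above is a direct transcription of the CELU definition; a quick numeric check of the decomposed formula against `torch.celu` (illustrative only):

    import torch

    x = torch.linspace(-3, 3, 11)
    alpha = 0.5
    # max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), as in the comment above.
    decomposed = torch.clamp(x, min=0) + torch.clamp(
        alpha * (torch.exp(x / alpha) - 1), max=0
    )
    assert torch.allclose(decomposed, torch.celu(x, alpha))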

namespace {
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
public:

@ -7705,6 +7748,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);

@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
  target.addIllegalOp<AtenPadOp>();
  target.addIllegalOp<AtenPreluOp>();
  target.addIllegalOp<AtenCeluOp>();
  target.addIllegalOp<AtenToDtypeLayoutOp>();
  target.addIllegalOp<AtenToDeviceOp>();
  target.addIllegalOp<AtenToPrimDeviceOp>();
@ -272,6 +272,7 @@ TORCHDYNAMO_XFAIL_SET = {
    "QuantizedReluInt8_basic",
    "QuantizedReluUint8_basic",
    "Conv2dQInt8Module_basic",
    "ConvTranspose2DQInt8_basic",
    # Dynamo not supporting conv_tbc
    "ConvTbcModule_basic",
    "FloatImplicitModule_basic",

@ -372,6 +373,7 @@ FX_IMPORTER_XFAIL_SET = {
    "Conv2dQInt8Module_basic",
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    "ConvTbcModule_basic",
    "ConvTranspose2DQInt8_basic",
    "ConvolutionBackwardModule2DPadded_basic",
    "ConvolutionBackwardModule2DStrided_basic",
    "ConvolutionBackwardModule2D_basic",

@ -544,6 +546,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
    "ContainsIntList_True",
    "Conv2dQInt8Module_basic",
    "ConvTbcModule_basic",
    "ConvTranspose2DQInt8_basic",
    "ConvolutionBackwardModule2DPadded_basic",
    "ConvolutionBackwardModule2DStrided_basic",
    "ConvolutionBackwardModule2D_basic",

@ -572,6 +575,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
    "ElementwiseErfIntModule_basic",
    "ElementwiseLogitModule_basic",
    "ElementwiseMulTensorComplexModule_basic",
    "ElementwiseMulTensorComplexDiffModule_basic",
    "ElementwiseQuantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorUIntModule_basic",
    "ElementwiseReciprocalIntModule_basic",

@ -678,11 +682,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
    "NumToTensorIntModule_basic",
    "NumelModule_basic",
    "NumelZeroRankModule_basic",
    "PixelShuffleModuleFullDynamic_basic",
    "PixelShuffleModuleSpatiallyDynamic_basic",
    "PixelShuffleModuleSpatiallyStatic_basic",
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "PowIntFloatModule_basic",
    "PrimMaxIntModule_basic",
    "PrimMinIntDynamicModule_basic",

@ -951,6 +950,7 @@ STABLEHLO_PASS_SET = {
    "ElementwiseBitwiseRightShiftInt64Module_basic",
    "ElementwiseBitwiseRightShiftInt8Module_basic",
    "ElementwiseCeilModule_basic",
    "ElementwiseCeluStaticModule_basic",
    "ElementwiseClampMaxModule_basic",
    "ElementwiseClampMinModule_basic",
    "ElementwiseClampMinTensorFloatModule_basic",

@ -1079,6 +1079,9 @@ STABLEHLO_PASS_SET = {
    "Matmul_vecmat",
    "MatmulStaticBroadcast_basic",
    "MaxPool2dStaticModule_basic",
    "MaxPool2dEmptyStrideStaticModule_basic",
    "MaxPool3dStaticModule_basic",
    "MaxPool3dEmptyStrideStaticModule_basic",
    "MeanDimAllReduceModule_basic",
    "MeanDimEmptyDimModule_basic",
    "MeanDimNoneDimModule_basic",

@ -1156,6 +1159,8 @@ STABLEHLO_PASS_SET = {
    "Permute0RankModule_basic",
    "PermuteModule_basic",
    "PermuteNegativeIndexModule_basic",
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "PowIntFloatModule_basic",
    "PrimListUnpackNumMismatchModule_basic",
    "PrimMaxIntModule_basic",

@ -1239,6 +1244,7 @@ STABLEHLO_PASS_SET = {
    "SliceWholeTensorModule_basic",
    "SortIntListReverse_basic",
    "SortIntList_basic",
    "SplitDimStaticModule_basic",
    "SplitTensorGetItem_Module_basic",
    "SplitTensorLastSmallerModule_basic",
    "SplitTensorListUnpackModule_basic",

@ -1571,6 +1577,8 @@ TOSA_PASS_SET = {
    "ElementwiseBitwiseXorModule_basic",
    "ElementwiseBitwiseXorStaticShapeModule_basic",
    "ElementwiseCeilModule_basic",
    "ElementwiseCeluModule_basic",
    "ElementwiseCeluStaticModule_basic",
    "ElementwiseClampMaxModule_basic",
    "ElementwiseClampMinModule_basic",
    "ElementwiseClampModule_basic",

@ -1916,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = (
    # Dynamic shape, has extra unsupported broadcast ops
    "Matmul_3d",
    "MatmulStaticBroadcast_basic",
    # failed to legalize operation 'torch.aten.max_pool2d_with_indices'
    "MaxPool2dEmptyStrideStaticModule_basic",
    "MaxPool2dStaticCeilModeTrueModule_basic",
    "MaxPool2dStaticModule_basic",
    "ResNet18StaticModule_basic",
    # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
    "IndexPutImpl1DFloatNonAccumulateModule_basic",
    "IndexPutImpl1DIntNonAccumulateModule_basic",

@ -2096,6 +2099,7 @@ LTC_XFAIL_SET = {
    "ElementwiseBitwiseAndScalarInt32Module_basic",
    "ElementwiseBitwiseAndScalarInt8Module_basic",
    "Conv2dQInt8Module_basic",
    "ConvTranspose2DQInt8_basic",
}

ONNX_XFAIL_SET = {

@ -2121,7 +2125,6 @@ ONNX_XFAIL_SET = {
    "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
    "ElementwiseLog10IntModule_basic",
    "ElementwiseLog2IntModule_basic",
    "ElementwiseSeluModule_basic",
    "FlipModuleStaticShape_basic",
    "FlipNegativeIndexModule_basic",
    "HardsigmoidModule_basic",

@ -2251,6 +2254,7 @@ ONNX_XFAIL_SET = {
    "Conv2dWithPaddingModule_basic",
    "Conv3dModule_basic",
    "ConvTbcModule_basic",
    "ConvTranspose2DQInt8_basic",
    "Conv_Transpose2dModule_basic",
    "Convolution2DModule_basic",
    "Convolution2DStridedModule_basic",

@ -2306,6 +2310,7 @@ ONNX_XFAIL_SET = {
    "ElementwiseExpm1Module_basic",
    "ElementwiseFmodTensor_Int_basic",
    "ElementwiseMulTensorComplexModule_basic",
    "ElementwiseMulTensorComplexDiffModule_basic",
    "ElementwiseOrTensorModule_basic",
    "ElementwiseOrTensorStaticShapeModule_basic",
    "ElementwiseQuantizePerTensorModule_basic",

@ -2554,16 +2559,12 @@ ONNX_XFAIL_SET = {
    "_ConvolutionDeprecated2DCudnnModule_basic",
    "_ConvolutionDeprecated2DDeterministicModule_basic",
    "_SoftmaxModule_basic",
    # Failure - onnx_import
    # Failure - onnx_lowering: onnx.AveragePool
    "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
    # Failure - onnx_lowering: onnx.If
    "DiagonalModule_basic",
    "DiagonalModule_nonsquare",
    "DiagonalModule_transposed",
    "DiagonalModule_with_dims",
    "DiagonalModule_with_dims_and_offset",
    "DiagonalModule_with_negative_dims",
    "DiagonalModule_with_offset",
    # These diagonal modules are currently failing due to dynamic shape.
    # We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
    # When the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
    "TileBigDimsSizeModule_basic",
    "TileSmallDimsSizeModule_basic",
    # Failure - onnx_lowering: onnx.MaxPool

@ -2634,8 +2635,6 @@ ONNX_XFAIL_SET = {
    "CopyWithDifferentDTypesModule_basic",
    "CosineSimilarityStaticBroadcastModule_basic",
    "CumsumInputDtypeInt32Module_basic",
    "DropoutTrainModule_basic",
    "DropoutTrainStaticShapeModule_basic",
    "ElementwiseAcosIntModule_basic",
    "ElementwiseAsinIntModule_basic",
    "ElementwiseAtanTensorIntModule_basic",
@ -3,6 +3,7 @@ from torch_mlir import torchscript

from transformers import BertForMaskedLM


# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
    def __init__(self) -> None:

@ -257,9 +257,9 @@ class _FXGraphImporter:
        # FakeTensor's in case of a tuple return with multiple elements.
        self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
        self._module = ir.Module.create(ir.Location.unknown())
        self._module.operation.attributes[
            "torch.debug_module_name"
        ] = ir.StringAttr.get(func_name)
        self._module.operation.attributes["torch.debug_module_name"] = (
            ir.StringAttr.get(func_name)
        )
        function_type = _extract_function_type_from_graph(g)
        func = func_dialect.FuncOp(
            func_name,
@ -526,6 +526,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇selu〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

@ -958,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd

# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
    assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
    assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
    kL = kernel_size[0]

    assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
    assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
    dL = kL if len(stride) == 0 else stride[0]

    assert len(padding) == 1, "avg_pool1d: padding must be a single int"
    assert len(padding) == 1, "pool1d: padding must be a single int"
    padL = padding[0]

    dilationL = 1

@ -1001,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
    return shape

def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
    return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
    return pool1d(self, kernel_size, stride, padding, ceil_mode)

def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
    return pool1d(self, kernel_size, stride, padding, ceil_mode)

def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
    return adaptive_avg_pool1d(self, output_size)
@ -2652,6 +2658,11 @@ def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu
    assert self_dtype == weight_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.))
def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype

@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
            (ns, unqual + "_", overload if not is_functional_op else "")
        ),
        emitter_td,
        traits=["IsTrailingUnderscoreInplaceVariant"]
        if not is_functional_op
        else [],
        traits=(
            ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
        ),
    )

    # ==========================================================================

@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
    emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
    emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
    emit("aten::real : (Tensor) -> (Tensor)")
    emit("aten::imag : (Tensor) -> (Tensor)")
    emit("aten::view_as_complex : (Tensor) -> (Tensor)")

@ -590,9 +591,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
    )
    emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
    emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
    emit(
        "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
        "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
        has_canonicalizer=True,
    )
    emit(
        "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"

@ -973,6 +976,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::format : (...) -> (str)")
    emit("aten::join : (str, str[]) -> (str)")
    emit("aten::warn : (str, int) -> ()")
    emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True)

    # Type conversion ops.
    emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)

@ -1101,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    # `prims::` namespace.
    # ==========================================================================

    emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
    emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
    emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
    emit("prims::sqrt : (Tensor) -> (Tensor)")
    emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
    examples = []
    input_names = []
    dynamic_tensors = {}
    for (index, arg) in enumerate(inputs):
    for index, arg in enumerate(inputs):
        shape = map(lambda d: d if d >= 0 else 1, arg.shape)
        shape = tuple(shape)
        examples.append(torch.zeros(size=shape, dtype=arg.dtype))

@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
        input_names.append(input_name)

        dynamic_dims = {}
        for (dimindex, dim) in enumerate(arg.shape):
        for dimindex, dim in enumerate(arg.shape):
            if dim < 0:
                dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

@ -101,10 +101,12 @@ class RefBackendInvoker:
        def consume_return_funcs(*args):
            self.result = tuple(
                [
                    arg
                    if type in elemental_type_to_ctype
                    else unranked_memref_to_numpy(
                        arg, memref_type_to_np_dtype[type]
                    (
                        arg
                        if type in elemental_type_to_ctype
                        else unranked_memref_to_numpy(
                            arg, memref_type_to_np_dtype[type]
                        )
                    )
                    for arg, type in zip(args, ret_types)
                ]

@ -178,6 +180,7 @@ LOWERING_PIPELINE = (
    "func.func(tm-tensor-to-loops)",
    "func.func(refback-munge-memref-copy)",
    "func.func(convert-linalg-to-loops)",
    "func.func(expand-realloc)",
    "func.func(lower-affine)",
    "convert-scf-to-cf",
    "func.func(refback-expand-ops-for-llvm)",

@ -191,6 +194,7 @@ LOWERING_PIPELINE = (
    "convert-bufferization-to-memref",
    "finalize-memref-to-llvm",
    "func.func(convert-arith-to-llvm)",
    "convert-vector-to-llvm",
    "convert-func-to-llvm",
    "convert-cf-to-llvm",
    "convert-complex-to-llvm",
@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
    weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
    bias = torch.rand(3)
    module.forward(inputVec, weight, bias)


N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2


class ConvTranspose2DQInt8Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1, -1], torch.int8, True),
            ([-1, -1, -1, -1], torch.int8, True),
            ([-1], torch.float, True),
        ]
    )
    def forward(self, input, weight, bias):
        qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
        qinput = torch.dequantize(qinput)
        qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
        qweight = torch.dequantize(qweight)
        qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
        qbias = torch.dequantize(qbias)
        qz = torch.ops.aten.convolution(
            qinput,
            qweight,
            bias=qbias,
            stride=[2, 1],
            padding=[1, 1],
            dilation=[1, 1],
            transposed=True,
            output_padding=[0, 0],
            groups=1,
        )
        return qz


@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
        tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
        torch.rand(Cout),
    )
@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
# ==============================================================================


class DiagonalWithStaticShapeModule(torch.nn.Module):
    """
    Diagonal with static shape. The other diagonal modules are failing in onnx
    because DecomposeAtenEyeMOp requires constants n, m, which are only constant
    when the shape is static.

    Please remove this module and associated test once the issue is fixed.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([5, 9], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.diagonal(a)


@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 9))


# ==============================================================================


class DiagonalTransposedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):

@register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
    )
    module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))


# ==============================================================================

@ -1016,6 +1014,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseCeluModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.celu(x, 0.5)


@register_test_case(module_factory=lambda: ElementwiseCeluModule())
def ElementwiseCeluModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseCeluStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([5, 3], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.celu(x)


@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule())
def ElementwiseCeluStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseGeluModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@ -1795,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
# ==============================================================================


# torch.complex32 is not supported by the refbackend.
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.complex64, True),
            ([-1], torch.complex128, True),
        ]
    )
    def forward(self, a, b):
        return torch.mul(a, b)


@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(4, high=10).type(torch.complex64),
        tu.randint(4, high=10).type(torch.complex128),
    )
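The new test exercises mixed complex widths; PyTorch's type promotion resolves a complex64/complex128 multiply to complex128, and the lowering has to preserve that:

    import torch

    a = torch.randn(4, dtype=torch.complex64)
    b = torch.randn(4, dtype=torch.complex128)
    # Mixed-width complex multiply promotes to the wider complex dtype.
    assert torch.mul(a, b).dtype == torch.complex128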

# ==============================================================================


class ElementwiseMishModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):

# ==============================================================================


# For aten.slice_scatter op, the arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, the arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module):

@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

mb = ModuleBuilder()


# CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module):
    def __init__(self):

@ -18,6 +18,7 @@ mb = ModuleBuilder()
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl.


# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
    def __init__(self):

@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

mb = ModuleBuilder()


# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
    def __init__(self):

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

mb = ModuleBuilder()


# CHECK-LABEL: @__torch__.f
@mb.import_function
@torch.jit.script

@ -11,6 +11,7 @@ import typing

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

@ -13,6 +13,7 @@ mb = ModuleBuilder()
# else branch and making all defined values optional, so no special handling
# is needed.


# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {

@ -11,6 +11,7 @@ import typing

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true

@ -15,6 +15,7 @@ import typing

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor

@ -13,6 +13,7 @@ from utils import create_script_function

mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])


# CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

mb = ModuleBuilder()


# CHECK: @__torch__.returns_bool
@mb.import_function
@torch.jit.script

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder

mb = ModuleBuilder()


# CHECK: @__torch__.returns_none
@mb.import_function
@torch.jit.script

@ -9,6 +9,7 @@ from torch._C import CompilationUnit

# RUN: %PYTHON %s


# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs):
    cu = CompilationUnit()
@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP

if _IS_TORCH_2_1_OR_EARLIER:
    SYMBOLIC_TORCH_OPS = {
        torch.ops.aten.sym_size,
        torch.ops.aten.sym_stride,
        torch.ops.aten.sym_numel,
    }

    SYMBOLIC_OP_TO_TORCH_OP = {
        (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
        (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,

@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER:
        (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
        (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
    }
else:
    SYMBOLIC_TORCH_OPS = {
        torch.ops.aten.sym_size.int,
        torch.ops.aten.sym_stride.int,
        torch.ops.aten.sym_numel.default,
    }

    SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
else:
    SYMBOLIC_OP_TO_TORCH_OP = {
        torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
        torch.ops.aten.sym_size.int: torch.ops.aten.size.int,

@ -264,6 +254,8 @@ else:
        torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
    }

    SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
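The refactor derives the op set from the mapping keys instead of maintaining a parallel literal; the only wrinkle is that the pre-2.2 keys are `(op, arity)` tuples, hence the `key[0]` projection in that branch. With hypothetical stand-in keys:

    old_style = {("sym_size", 1): "size.default", ("sym_size", 2): "size.int"}
    new_style = {"sym_size.int": "size.int", "sym_numel.default": "numel.default"}

    ops_old = {key[0] for key in old_style}  # strip the arity component
    ops_new = {key for key in new_style}     # keys already are the ops
    assert ops_old == {"sym_size"}
    assert "sym_size.int" in ops_new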

@dataclass(frozen=True)
class SparsityMeta:
@ -1857,8 +1849,7 @@ def _emit_operation(

# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
    ...
class EmptyType: ...


Empty = EmptyType()

@ -156,8 +156,7 @@ class GraphInfo:
        return ""


class OnnxImportError(Exception):
    ...
class OnnxImportError(Exception): ...


class NodeImporter:

@ -235,22 +234,22 @@ class NodeImporter:
        else:
            default_opset_version = opset_import.version
        if default_opset_version:
            container_op.attributes[
                "torch.onnx_meta.opset_version"
            ] = IntegerAttr.get(i64_type, default_opset_version)
            container_op.attributes["torch.onnx_meta.opset_version"] = (
                IntegerAttr.get(i64_type, default_opset_version)
            )
        if opset_versions:
            container_op.attributes[
                "torch.onnx_meta.opset_versions"
            ] = DictAttr.get(opset_versions)
            container_op.attributes["torch.onnx_meta.opset_versions"] = (
                DictAttr.get(opset_versions)
            )
        container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
            IntegerType.get_signed(64), m.ir_version
        )
        container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
            m.producer_name
        )
        container_op.attributes[
            "torch.onnx_meta.producer_version"
        ] = StringAttr.get(m.producer_version)
        container_op.attributes["torch.onnx_meta.producer_version"] = (
            StringAttr.get(m.producer_version)
        )

    def import_all(self, func=True):
        """Imports all nodes topologically."""

@ -348,8 +347,14 @@ class NodeImporter:
                continue
            elif handler is False:
                # Active error.
                # Try matching the attribute type ID to its name for a more
                # descriptive error message.
                try:
                    attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
                except ValueError:
                    attr_type_name = "UNKNOWN"
                raise OnnxImportError(
                    f"ONNX importer does not support generic node attribute type {attr_type}. "
                    f"ONNX importer does not support generic node attribute type {attr_type_name} "
                    f"with ID {attr_type}. "
                    f"This likely means that this is a special node which requires specific "
                    f"handling in the importer: {onnx_attr}"
                )
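`onnx.AttributeProto.AttributeType.Name` is the protobuf enum's reverse mapping from numeric ID to symbolic name, which is what makes the new message readable; unknown IDs raise `ValueError`, hence the "UNKNOWN" fallback. For instance:

    import onnx

    # Reverse mapping used by the error-message improvement above.
    assert onnx.AttributeProto.AttributeType.Name(1) == "FLOAT"
    assert onnx.AttributeProto.AttributeType.Name(2) == "INT"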
@ -658,9 +663,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
        RankedTensorType.get(shape, IntegerType.get_signed(64)),
        IntegerAttr.get(
            IntegerType.get_signed(64),
            int.from_bytes(tp.raw_data, "little", signed=True)
            if tp.HasField("raw_data")
            else tp.int64_data[0],
            (
                int.from_bytes(tp.raw_data, "little", signed=True)
                if tp.HasField("raw_data")
                else tp.int64_data[0]
            ),
        ),
    ),
    # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB

@ -703,7 +710,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
    ),
    onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
        np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
    )
    ),
    # Intentionally unsupported: STRING
}

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Optional, Union, Dict, Tuple, Any
from typing import Optional, Union, Dict, Tuple, Any, Callable

import warnings

@ -25,7 +25,7 @@ def export_and_import(
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    experimental_support_mutation: bool = False,
    hooks: Optional[FxImporterHooks] = None,
    decomposition_table: Optional[list] = None,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    func_name: str = "main",
    enable_graph_printing: bool = False,
    **kwargs,
@ -1 +1 @@
0a3e5f5badd8a0cb7fac97f5ec9d48c304e5c0b7
34ade3521ca41f20af3469bba276c2b0499c3892

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.4.0.dev20240422
torch==2.4.0.dev20240428

@ -0,0 +1,20 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s

// CHECK-LABEL: func.func @test_ifop_basic
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
// CHECK-DAG: } else {
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
    %1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
    torch.operator_terminator %1 : !torch.vtensor<[1],f32>
  }, {
    %1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
    torch.operator_terminator %1 : !torch.vtensor<[1],f32>
  }
  return %0 : !torch.vtensor<[1],f32>
}
|
@ -1996,3 +1996,160 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten
|
|||
%0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
|
||||
return %0 : !torch.vtensor<[3,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_blackmanwindow_symmetric
|
||||
func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
|
||||
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
|
||||
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
|
||||
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
|
||||
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
|
||||
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
|
||||
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
|
||||
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
|
||||
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
|
||||
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||
return %0 : !torch.vtensor<[10],f32>
|
||||
}

// -----

// CHECK-LABEL: func.func @test_blackmanwindow
func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
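A minimal NumPy sketch of what these checks encode (an illustration, not part of the test suite): the lowering folds the periodic/symmetric choice into arithmetic, size_fp = periodic*N + symmetric*(N-1), then evaluates the cosine series over arange(N).

import numpy as np

def blackman_window(n: int, periodic: bool) -> np.ndarray:
    # size_fp = periodic*N + symmetric*(N-1), as in the SIZEFP checks above
    # (assumes n > 1 in the symmetric case)
    m = float(n) if periodic else float(n - 1)
    k = np.arange(n, dtype=np.float32)
    ang = 2.0 * np.pi * k / m
    return (0.42 - 0.5 * np.cos(ang) + 0.08 * np.cos(2.0 * ang)).astype(np.float32)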

// -----

// CHECK-LABEL: func.func @test_hannwindow
func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>

%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
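Relative to BlackmanWindow above, the HannWindow lowering reuses the same generalized-cosine recipe with coefficients (a0, a1, a2) = (0.5, -0.5, 0.0); the A2COMP term checked here is therefore a multiply by zero, which the test still pins down so the decomposition stays uniform across window ops.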

// -----

// CHECK-LABEL: func.func @test_hannwindow_symmetric
func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>

%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}

@@ -860,6 +860,57 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
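As a rough NumPy illustration of the semantics pinned down here (not part of the test suite): with an empty axes input and keepdims = 1, onnx.ReduceLogSum reduces over every dimension, keeps them as size-1 dims, and takes the log of the result.

import numpy as np

x = np.random.rand(3, 2, 2).astype(np.float32)
out = np.log(x.sum(axis=None, keepdims=True))   # shape (1, 1, 1): sum over all dims, then log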

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0

@@ -942,41 +993,24 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example
func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
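ReduceSumSquare has no single torch counterpart here; the lowering rewrites it as an elementwise self-multiply followed by an ordinary sum reduction, which is exactly the MULT-then-SUM pair the checks in this hunk pin down. A rough NumPy equivalent (illustrative only):

import numpy as np

x = np.random.rand(3, 2, 2).astype(np.float32)
out = (x * x).sum(axis=None, keepdims=True)     # aten.mul.Tensor followed by aten.sum.dim_IntList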

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>

@@ -984,15 +1018,65 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2

// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero
func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8: si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32>
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
return %0 : !torch.vtensor<[2,0,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example
func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example
func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}

// -----

// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>

@@ -55,7 +55,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt

// CHECK-LABEL: func.func @torch.aten.reciprocal(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
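The pairs of near-identical CHECK lines in this and the following hunks track an upstream MLIR printer change: inherent attributes stored as op properties are now printed as `<{...}>` instead of the plain `{...}` attribute dictionary, so for example `"chlo.constant_like"(...) {value = 1.0 : f32}` becomes `"chlo.constant_like"(...) <{value = 1.0 : f32}>` with no semantic change to the op.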

@@ -124,7 +124,7 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],

// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {

@@ -152,7 +152,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)

// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
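For orientation, stablehlo.batch_norm_inference applies the usual normalization with precomputed statistics over the feature axis (feature_index = 1 here, i.e. the channel dim of an NCHW tensor). A NumPy sketch with illustrative names, using an epsilon comparable to the 9.99999974E-6 in the checks:

import numpy as np

def batch_norm_inference(x, scale, offset, mean, var, eps=1e-5):
    # broadcast per-channel stats over an NCHW tensor (feature_index = 1)
    shape = (1, -1, 1, 1)
    return (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps) \
        * scale.reshape(shape) + offset.reshape(shape)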

@@ -185,7 +185,7 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)

// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {

@@ -214,7 +214,7 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],

// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>

@@ -4,9 +4,9 @@

// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "none"
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>

@@ -487,7 +487,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch

// CHECK-LABEL: func.func @torch.aten.relu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 0.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>

@@ -10,7 +10,7 @@

// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>

@@ -31,7 +31,7 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1

// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>

@@ -53,7 +53,7 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic

// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>

@@ -14,11 +14,11 @@

// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {

@@ -46,12 +46,12 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch

// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: })
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {

@@ -96,7 +96,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -

// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>}> ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>

@@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -

// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: }) : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
// CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>

@@ -137,11 +137,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>

// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>

@@ -158,11 +159,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>

// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>

@@ -194,11 +196,12 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch

// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
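Note the two divisor strategies across these avg_pool2d checks: the previous test divides the windowed sum by a second reduce_window over a tensor of ones (a per-position valid-element count, which appears to correspond to count_include_pad = false), while this one divides by the constant window size 9, i.e. 3x3 with padded zeros counted. A 1-D NumPy analogue (illustrative only; stride 1, zero padding of 1):

import numpy as np

x = np.array([1., 2., 3., 4.], dtype=np.float32)
xp = np.pad(x, 1)                                             # zero padding
sums = np.lib.stride_tricks.sliding_window_view(xp, 3).sum(-1)
avg_include_pad = sums / 3.0                                  # constant divisor (window size)
ones = np.pad(np.ones_like(x), 1)
counts = np.lib.stride_tricks.sliding_window_view(ones, 3).sum(-1)
avg_exclude_pad = sums / counts                               # per-position divisor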

@@ -22,10 +22,10 @@

// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
// CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
// CHECK: stablehlo.return %[[ARG_4]] : tensor<i64>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false} : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: }) : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {

@@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool {

// CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
// CHECK-NEXT: return %[[F]] : !torch.bool
// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-NEXT: return %[[TRUE]] : !torch.bool
func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
return %0 : !torch.bool

@@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool {

}

// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool true
// CHECK: return %[[FALSE]] : !torch.bool
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
%str4 = torch.constant.str "4"
%str5 = torch.constant.str "5"

@@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool {

// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
// CHECK-NEXT: return %[[F]] : !torch.bool
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-NEXT: return %[[FALSE]] : !torch.bool
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
return %0 : !torch.bool
}

// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool false
// CHECK: return %[[TRUE]] : !torch.bool
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
%str4 = torch.constant.str "4"
%str4_0 = torch.constant.str "4"

@@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int {

return %2 : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
%str = torch.constant.str "c"
%str_0 = torch.constant.str "b"
%str_1 = torch.constant.str "a"
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
return %2 : !torch.bool
}

// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
%str = torch.constant.str "aa"
%str_0 = torch.constant.str "aa"
%str_1 = torch.constant.str "ccc"
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
return %2 : !torch.bool
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.__not__
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
|
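(The two new `__contains__.str_list` tests exercise a fold that only fires when every operand is a compile-time constant. A short Python sketch of that condition, again illustrative rather than the real folder, with `None` standing in for an unknown SSA value:)

    from typing import List, Optional

    def fold_contains_str_list(needle: Optional[str],
                               elements: List[Optional[str]]) -> Optional[bool]:
        """Fold `aten.__contains__.str_list` only when the needle and every
        list element are known constants; otherwise keep the op."""
        if needle is None or any(e is None for e in elements):
            return None
        return needle in elements

    # The two tests above: "c" is not in ["a", "b"]; "aa" is in ["ccc", "aa"].
    assert fold_contains_str_list("c", ["a", "b"]) is False
    assert fold_contains_str_list("aa", ["ccc", "aa"]) is True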
@@ -2950,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
   %result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
   return %result : !torch.vtensor<[4], f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
+// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
+func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+  return %0 : !torch.vtensor<[64],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
+// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
+// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
+func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
+  return %0 : !torch.vtensor<[64],si32>
+}
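(For context: `%int6` is PyTorch's ScalarType code for float32, so the `$fold` case converts f32 to f32, an identity, while the `$no_fold` case produces si32. A minimal Python sketch of the fold condition as I read these tests, keyed on the input and result element types rather than the raw dtype code; this is illustrative, not the torch-mlir implementation:)

    import torch

    def convert_folds_to_identity(input_dtype: torch.dtype,
                                  result_dtype: torch.dtype) -> bool:
        """prims.convert_element_type folds to its input exactly when the
        result element type already equals the input element type."""
        return input_dtype == result_dtype

    assert convert_folds_to_identity(torch.float32, torch.float32)    # $fold
    assert not convert_folds_to_identity(torch.float32, torch.int32)  # $no_fold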
+
+// -----
+
+// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
+// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
+// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
+func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
+  return %result0 : !torch.vtensor<[10,64,56,56],f32>
+}
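(This canonicalization drops the unused indices result: when only the pooled values are consumed, `max_pool2d_with_indices` can be rewritten to plain `max_pool2d`. A quick eager-mode check of that equivalence at the test's exact parameters, assuming standard torch.nn.functional semantics:)

    import torch
    import torch.nn.functional as F

    x = torch.randn(10, 64, 112, 112)
    vals, _indices = F.max_pool2d_with_indices(
        x, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
    )
    pooled = F.max_pool2d(
        x, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
    )
    # Same values, and the 10x64x56x56 result shape the CHECK lines expect.
    assert torch.equal(vals, pooled)
    assert vals.shape == (10, 64, 56, 56)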
@@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
     print(m)


+@run
+# CHECK-LABEL: test_broadcast_with_dynamic_shapes
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
+def test_broadcast_with_dynamic_shapes():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.broadcast_to(x, (y.shape[0], -1))
+
+    # Sample inputs
+    x = torch.randn(1, 2)
+    y = torch.randn(10)
+
+    dim_0 = Dim("dim_0")
+    dynamic_shapes = {
+        "x": {},
+        "y": {0: dim_0},
+    }
+
+    m = fx.export_and_import(
+        Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
+    )
+    print(m)
+
+
 @make_boxed_compiler
 def fx_import_aot_autograd_backend(
     gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]

@@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(

 @run
 # CHECK-LABEL: test_stateless_fx_import
-# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
 def test_stateless_fx_import():
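(For context on the new test: `fx.export_and_import` runs the module through `torch.export`, and the `dynamic_shapes` dict is the standard torch.export mechanism; `{}` pins `x` to its example shape while `{0: dim_0}` marks the leading dimension of `y` symbolic, which is what produces the `?` in the checked `!torch.vtensor<[?],f32>` and `[?,2]` types. A standalone sketch using torch.export directly, assuming a recent PyTorch with the `torch.export` API:)

    import torch
    from torch.export import Dim, export

    class Basic(torch.nn.Module):
        def forward(self, x, y):
            return torch.broadcast_to(x, (y.shape[0], -1))

    # "x": {} keeps x fully static; "y": {0: dim_0} makes dim 0 symbolic.
    dim_0 = Dim("dim_0")
    ep = export(
        Basic(),
        (torch.randn(1, 2), torch.randn(10)),
        dynamic_shapes={"x": {}, "y": {0: dim_0}},
    )
    print(ep)  # the exported graph carries a symbolic size for y's leading dim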
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.19.0.dev20240422
+torchvision==0.19.0.dev20240428