Merge branch 'main' into lower_torch_aten_gcd_to_linalg_and_scf

pull/3732/head
bratislavSyrmia 2024-10-10 22:45:58 +02:00 committed by GitHub
commit ee7f6ee9fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 3131 additions and 513 deletions

View File

@ -50,7 +50,7 @@ TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}"
# Location to store Release wheels # Location to store Release wheels
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}" TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
# What "packages to build" # What "packages to build"
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}" TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-ext}"
# Use pre-built Pytorch # Use pre-built Pytorch
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}" TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
# Skip running tests if you want quick iteration # Skip running tests if you want quick iteration
@ -83,12 +83,12 @@ function run_on_host() {
fi fi
mkdir -p "${TM_OUTPUT_DIR}" mkdir -p "${TM_OUTPUT_DIR}"
case "$package" in case "$package" in
torch-mlir) torch-mlir-ext)
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE} TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
export USERID=0 export USERID=0
export GROUPID=0 export GROUPID=0
;; ;;
torch-mlir-core) torch-mlir)
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE} TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
export USERID=0 export USERID=0
export GROUPID=0 export GROUPID=0
@ -158,22 +158,22 @@ function run_in_docker() {
export PATH=$python_dir/bin:$orig_path export PATH=$python_dir/bin:$orig_path
echo ":::: Python version $(python3 --version)" echo ":::: Python version $(python3 --version)"
case "$package" in case "$package" in
torch-mlir) torch-mlir-ext)
clean_wheels torch_mlir "$python_version" clean_wheels torch_mlir_ext "$python_version"
build_torch_mlir "$TM_TORCH_VERSION" build_torch_mlir_ext "$TM_TORCH_VERSION"
# Disable audit wheel until we can fix ODR torch issues. See # Disable audit wheel until we can fix ODR torch issues. See
# https://github.com/llvm/torch-mlir/issues/1709 # https://github.com/llvm/torch-mlir/issues/1709
# #
#run_audit_wheel torch_mlir "$python_version" #run_audit_wheel torch_mlir_ext "$python_version"
clean_build torch_mlir "$python_version" clean_build torch_mlir_ext "$python_version"
;; ;;
torch-mlir-core) torch-mlir)
clean_wheels torch_mlir_core "$python_version" clean_wheels torch_mlir "$python_version"
build_torch_mlir_core build_torch_mlir
run_audit_wheel torch_mlir_core "$python_version" run_audit_wheel torch_mlir "$python_version"
clean_build torch_mlir_core "$python_version" clean_build torch_mlir "$python_version"
;; ;;
out-of-tree) out-of-tree)
setup_venv "$python_version" "$TM_TORCH_VERSION" setup_venv "$python_version" "$TM_TORCH_VERSION"
@ -431,7 +431,7 @@ function clean_build() {
rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch
} }
function build_torch_mlir() { function build_torch_mlir_ext() {
# Disable LTC build for releases # Disable LTC build for releases
export TORCH_MLIR_ENABLE_LTC=0 export TORCH_MLIR_ENABLE_LTC=0
local torch_version="$1" local torch_version="$1"
@ -470,7 +470,9 @@ function run_audit_wheel() {
rm "$generic_wheel" rm "$generic_wheel"
} }
function build_torch_mlir_core() { function build_torch_mlir() {
# Disable LTC build for releases
export TORCH_MLIR_ENABLE_LTC=0
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
CMAKE_GENERATOR=Ninja \ CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \

View File

@ -56,16 +56,16 @@ function run() {
export PATH=$python_dir/bin:$orig_path export PATH=$python_dir/bin:$orig_path
echo ":::: Python version $(python3 --version)" echo ":::: Python version $(python3 --version)"
case "$package" in case "$package" in
torch-mlir-ext)
clean_wheels torch_mlir_ext "$python_version"
build_torch_mlir_ext torch_mlir_ext "$python_version"
run_audit_wheel torch_mlir_ext "$python_version"
;;
torch-mlir) torch-mlir)
clean_wheels torch_mlir "$python_version" clean_wheels torch_mlir "$python_version"
build_torch_mlir torch_mlir "$python_version" build_torch_mlir torch_mlir "$python_version"
run_audit_wheel torch_mlir "$python_version" run_audit_wheel torch_mlir "$python_version"
;; ;;
torch-mlir-core)
clean_wheels torch_mlir_core "$python_version"
build_torch_mlir_core torch_mlir_core "$python_version"
run_audit_wheel torch_mlir_core "$python_version"
;;
*) *)
echo "Unrecognized package '$package'" echo "Unrecognized package '$package'"
exit 1 exit 1
@ -75,7 +75,7 @@ function run() {
done done
} }
function build_torch_mlir() { function build_torch_mlir_ext() {
local wheel_basename="$1" local wheel_basename="$1"
local python_version="$2" local python_version="$2"
rm -rf "$output_dir"/build_venv rm -rf "$output_dir"/build_venv
@ -93,7 +93,7 @@ function build_torch_mlir() {
rm -rf "$output_dir"/build_venv rm -rf "$output_dir"/build_venv
} }
function build_torch_mlir_core() { function build_torch_mlir() {
local wheel_basename="$1" local wheel_basename="$1"
local python_version="$2" local python_version="$2"
rm -rf "$output_dir"/build_venv rm -rf "$output_dir"/build_venv

View File

@ -14,7 +14,7 @@ While this is running, you can already setup the Python venv and dependencies in
## Setup your Python VirtualEnvironment and Dependencies ## Setup your Python VirtualEnvironment and Dependencies
```shell ```shell
python -m venv mlir_venv python3 -m venv mlir_venv
source mlir_venv/bin/activate source mlir_venv/bin/activate
# Some older pip installs may not be able to handle the recent PyTorch deps # Some older pip installs may not be able to handle the recent PyTorch deps
python -m pip install --upgrade pip python -m pip install --upgrade pip

@ -1 +1 @@
Subproject commit d418a03e01e6a31b51b0c9dd42ba46da6c47f89d Subproject commit e813750354bbc08551cf23ff559a54b4a9ea1f29

2
externals/stablehlo vendored

@ -1 +1 @@
Subproject commit c28d55e91b4a5daaff18a33ce7e9bbd0f171256a Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739

View File

@ -34,6 +34,7 @@ struct OpBinder {
Location getLoc() { return op->getLoc(); } Location getLoc() { return op->getLoc(); }
int getNumOperands() { return op->getNumOperands(); } int getNumOperands() { return op->getNumOperands(); }
int getNumResults() { return op->getNumResults(); }
// Operand matches of different arities. // Operand matches of different arities.
ParseResult tensorOperand(Value &value0) { ParseResult tensorOperand(Value &value0) {
@ -338,6 +339,31 @@ struct OpBinder {
return failure(); return failure();
} }
ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
StringRef nameSuffix,
ArrayRef<float> defaults) {
SmallString<64> name("torch.onnx.");
name.append(nameSuffix);
auto attr = op->getAttr(name);
if (!attr) {
values.append(defaults.begin(), defaults.end());
return success();
}
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
auto floatAttr = dyn_cast<FloatAttr>(element);
if (!floatAttr)
return failure();
FloatType t = cast<FloatType>(floatAttr.getType());
if (t.getWidth() != 32)
return failure();
values.push_back(floatAttr.getValue().convertToFloat());
}
return success();
}
return failure();
}
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values, ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
StringRef nameSuffix) { StringRef nameSuffix) {
SmallString<64> name("torch.onnx."); SmallString<64> name("torch.onnx.");

View File

@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions, Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result); Value input, Value &result);
// Flips an input tensor based on the values of axis list.
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
SmallVector<int64_t> axis);
} // namespace torch_to_linalg } // namespace torch_to_linalg
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -10,7 +10,7 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project

View File

@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy); Type elemTy);
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy);
Value castIntToIndex(OpBuilder &b, Location loc, Value v); Value castIntToIndex(OpBuilder &b, Location loc, Value v);

View File

@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [
}]; }];
} }
def Torch_AtenComplexOp : Torch_Op<"aten.complex", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$real,
AnyTorchTensorType:$imag
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenComplexOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenRealOp : Torch_Op<"aten.real", [ def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement, AllowsTypeRefinement,
ReadOnly ReadOnly
@ -7078,6 +7102,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
}]; }];
} }
def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result0,
AnyTorchOptionalTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -9195,6 +9248,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy
}]; }];
} }
def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$pos_weight,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [ def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -13637,6 +13717,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [
}]; }];
} }
def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dimension,
Torch_IntType:$size,
Torch_IntType:$step
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUnfoldOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return failure(); return failure();
auto shapeSizes = shapeType.getSizes(); auto shapeSizes = shapeType.getSizes();
int64_t dataRank = dataType.getSizes().size(); ArrayRef<int64_t> dataShape = dataType.getSizes();
int64_t dataRank = dataShape.size();
int64_t shapeRank = shapeSizes.size(); int64_t shapeRank = shapeSizes.size();
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize) if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
return failure(); return failure();
@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// we are using torch implementation Torch::AtenBroadcastToOp which // we are using torch implementation Torch::AtenBroadcastToOp which
// takes list of int // takes list of int
for (int i = 0; i < shapeSizes[0]; i++) { for (int i = 0; i < shapeSizes[0]; i++) {
// extract dim from shape
Value selectIndex = rewriter.create<Torch::ConstantIntOp>( Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(), loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>( Value extract = rewriter.create<Torch::AtenSelectIntOp>(
loc, selectResultType, shape, zero, selectIndex); loc, selectResultType, shape, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>( Value selectDim = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), extract); loc, rewriter.getType<Torch::IntType>(), extract);
// compute dim to pass to broadcast op. For non-broadcastable dims,
if (i + rankDifference >= 0) { // pass -1
Value dim;
if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
// 1. if dataShape[i + rankDiff] > 1, then this cannot be
// broadcasted
// 2. we will explicitly disallow broadcasting dynamic dims that are
// secretly 1.
dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
// Assert dataShape[i + rankDiff] >= selectDim. If both are
// constant, this should fold out.
Value iv = Value iv =
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference); rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
auto sz = rewriter.create<Torch::AtenSizeIntOp>( auto sz = rewriter.create<Torch::AtenSizeIntOp>(
loc, rewriter.getType<Torch::IntType>(), data, iv); loc, rewriter.getType<Torch::IntType>(), data, iv);
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz); Value gtSelect =
rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
rewriter.create<Torch::RuntimeAssertOp>(
loc, gtSelect,
rewriter.getStringAttr(
"onnx.Expand input has a dim that is not statically 1; "
"expected this dim >= dim provided shape."));
} else {
// 1. excess selectDims get included in broadcast (shapeSizes[0] >
// dataRank)
// 2. selectDims which correspond to dataShape == 1 get included in
// broadcast
dim = selectDim;
} }
dimList.push_back(dim); dimList.push_back(dim);
} }
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>( Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(

View File

@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// TODO: Implement max and min cases // TODO: Implement max and min cases
if (reduction == "mul") { if (reduction == "mul") {
reduction = "multiply"; reduction = "prod";
} else if (reduction == "max" || reduction == "min") { } else if (reduction == "max" || reduction == "min") {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "max/min reduction unsupported for scatter elements"); binder.op, "max/min reduction unsupported for scatter elements");
} else if (reduction == "add") {
reduction = "sum";
} }
Value cstStrReduction = Value cstStrReduction =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction); rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
Value cstTrue =
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>( rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
binder.op, resultType, data, constAxis, indices, updates, binder.op, resultType, data, constAxis, indices, updates,
cstStrReduction); cstStrReduction, cstTrue);
return success(); return success();
}); });
patterns.onOp( patterns.onOp(
@ -1662,10 +1665,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto shapeType = Torch::ValueTensorType::get( auto shapeType = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{inputRank}, binder.op->getContext(), SmallVector<int64_t>{inputRank},
resultType.getOptionalDtype()); resultType.getOptionalDtype());
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>( Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
binder.getLoc(), shapeType, operand); binder.getLoc(), shapeType, operand);
if (inputRank == 0) {
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
binder.op, resultType, shape);
return success();
}
if (start == 0 && end == -1) { if (start == 0 && end == -1) {
rewriter.replaceOp(binder.op, shape); rewriter.replaceOp(binder.op, shape);
return success(); return success();
@ -1673,18 +1681,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value sv = rewriter.create<Torch::ConstantIntOp>( Value sv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(start)); binder.getLoc(), rewriter.getI64IntegerAttr(start));
Value ev = rewriter.create<Torch::ConstantIntOp>( Value ev = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(end)); binder.getLoc(), rewriter.getI64IntegerAttr(end));
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1); Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0); Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
shape = rewriter.create<Torch::AtenSliceTensorOp>( rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
binder.getLoc(), resultType, shape, dim, sv, ev, step); binder.op, resultType, shape, dim, sv, ev, step);
rewriter.replaceOp(binder.op, shape);
return success(); return success();
}); });
@ -4339,6 +4342,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<int64_t> ngram_counts; llvm::SmallVector<int64_t> ngram_counts;
llvm::SmallVector<int64_t> ngram_indexes; llvm::SmallVector<int64_t> ngram_indexes;
llvm::SmallVector<int64_t> pool_int64s; llvm::SmallVector<int64_t> pool_int64s;
llvm::SmallVector<float> weights;
std::string mode; std::string mode;
int64_t min_gram_length; int64_t min_gram_length;
int64_t max_gram_length; int64_t max_gram_length;
@ -4356,9 +4360,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorOperand(input) || binder.tensorResultType(resultType)) binder.tensorOperand(input) || binder.tensorResultType(resultType))
return failure(); return failure();
if (mode != "TF") llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
return rewriter.notifyMatchFailure(binder.op, if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
"TF mode supported only"); return failure();
if (pool_int64s.size() == 0) if (pool_int64s.size() == 0)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "pool_int64s empty, only integers supported"); binder.op, "pool_int64s empty, only integers supported");
@ -4584,9 +4589,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.getLoc(), loopConditionTrue, ValueRange({count})); binder.getLoc(), loopConditionTrue, ValueRange({count}));
} }
count = skipLoop.getResult(0); count = skipLoop.getResult(0);
// insert count "tf" into output
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>( Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
binder.getLoc(), count); binder.getLoc(), count);
if (mode == "IDF" || mode == "TFIDF") {
// both IDF and TFIDF modes use weights
float weight = weights[ngram_i];
Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(weight));
// TFIDF
Value multiplier = countFloat;
if (mode == "IDF") {
// All the counts larger than 1 would be truncated to 1
// and the i-th element in weights would be used to scale
// (by multiplication) the count of the i-th n-gram in pool.
Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
binder.getLoc(), count);
// compare intCount > 0
Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
binder.getLoc(), intCount, zero);
gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), gtZeroCount);
Value gtZeroCountFloat =
rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
gtZeroCount);
multiplier = gtZeroCountFloat;
}
countFloat = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), multiplier, constWeight);
}
Value dataList = rewriter.create<Torch::PrimListConstructOp>( Value dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), binder.getLoc(),
rewriter.getType<Torch::ListType>( rewriter.getType<Torch::ListType>(

View File

@ -661,8 +661,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
std::string direction; std::string direction;
ValueTensorType yTy, Y_hType, Y_cType; ValueTensorType yTy, Y_hType, Y_cType;
if (binder.tensorResultTypeAtIndex(yTy, 0) || if (binder.tensorResultTypeAtIndex(yTy, 0) &&
binder.tensorResultTypeAtIndex(Y_hType, 1) || binder.tensorResultTypeAtIndex(Y_hType, 1) &&
binder.tensorResultTypeAtIndex(Y_cType, 2)) { binder.tensorResultTypeAtIndex(Y_cType, 2)) {
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,
"At least one outputs must be present"); "At least one outputs must be present");
@ -686,51 +686,110 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
auto xTy = cast<ValueTensorType>(X.getType()); auto xTy = cast<ValueTensorType>(X.getType());
auto wTy = cast<ValueTensorType>(W.getType()); auto wTy = cast<ValueTensorType>(W.getType());
Value B;
if (binder.tensorOperandAtIndex(B, 3)) { // TODO: add defaults for activation_alpha acticvation_beta attributes
B = b.create<AtenZerosOp>(W.getType(), W);
}
llvm::SmallVector<std::string> activationsList; llvm::SmallVector<std::string> activationsList;
if (binder.stringArrayAttr(activationsList, "activations")) if (binder.stringArrayAttr(activationsList, "activations"))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "Missing required attribute; activations"); binder.op, "Missing required attribute; activations");
LstmActivations activations; if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
activations.f = "Sigmoid"; direction != "forward" && direction != "bidirectional")
activations.g = "Tanh";
activations.h = "Tanh";
if (activationsList.size() == 3) {
activations.f = activationsList[0];
activations.g = activationsList[1];
activations.h = activationsList[2];
} else if (activationsList.size() != 0) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "activations must be empty have 3 elements, but " + binder.op, "Unsupported direction attribute value. "
"Only 'forward' / 'bidrectional' are supported but '" +
direction + "' is provided.");
int64_t num_directions = 1 + (direction == "bidirectional");
bool isBidirectional = direction == "bidirectional";
// There can be backward activations too
// if backward -> look for 6 atcivations (what happens when only three?)
int64_t num_activations = activationsList.size();
if (num_activations != 0 && num_activations != 3 && num_activations != 6) {
return rewriter.notifyMatchFailure(
binder.op, "activations must either be empty (default), have 3 elements"
" (forward) or, have 6 elements (bidirectional), but " +
std::to_string(activationsList.size()) + std::to_string(activationsList.size()) +
" are provided."); " are provided.");
} }
// TODO : Add checks, defaults and fails for inputs - sequence_lens, P and
// attrs- clip, input_forget, layout
if (!binder.customOpNameStringAttr(direction, "direction", "forward") && Value B;
direction != "forward") if (binder.tensorOperandAtIndex(B, 3)) {
Value none = b.create<ConstantNoneOp>();
Value cstHiddenx8 = b.create<ConstantIntOp>(
b.getType<IntType>(), b.getI64IntegerAttr(8 * hidden_size));
Value cstNumDir = b.create<ConstantIntOp>(
b.getType<IntType>(), b.getI64IntegerAttr(num_directions));
auto BType = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{num_directions, 8 * hidden_size},
cast<ValueTensorType>(W.getType()).getDtype());
Value zerosShapeList = b.create<PrimListConstructOp>(
b.getType<ListType>(b.getType<IntType>()),
SmallVector<Value>{cstNumDir, cstHiddenx8});
B = b.create<AtenZerosOp>(BType, zerosShapeList, none, none, none, none);
}
LstmActivations activations, activationsRev;
// Default case (both forward and reverse)
activations.f = "Sigmoid";
activations.g = "Tanh";
activations.h = "Tanh";
activationsRev.f = "Sigmoid";
activationsRev.g = "Tanh";
activationsRev.h = "Tanh";
// forward only (also to be added for bidirectional case)
if (num_activations >= 3) {
activations.f = activationsList[0];
activations.g = activationsList[1];
activations.h = activationsList[2];
}
// bidirectional
if (num_activations == 6) {
activationsRev.f = activationsList[3];
activationsRev.g = activationsList[4];
activationsRev.h = activationsList[5];
}
float clip;
if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0)
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,
"Unsupported direction attribute value. " "clip attribute not supported");
"Only 'forward' is supported but '" +
direction + "' is provided."); int64_t input_forget;
int64_t num_directions = 1 + (direction == "bidirectional"); if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) &&
input_forget != 0)
return rewriter.notifyMatchFailure(
binder.op, "only input_forget = 0 supported. Got input_forgt = " +
std::to_string(input_forget));
int64_t layout;
if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1)
return rewriter.notifyMatchFailure(
binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
std::to_string(layout));
auto XShape = xTy.getSizes(); auto XShape = xTy.getSizes();
int64_t batch_size = XShape[1]; int64_t seq_len, batch_size;
if (layout == 0) {
seq_len = XShape[0];
batch_size = XShape[1];
} else {
seq_len = XShape[1];
batch_size = XShape[0];
}
int64_t input_size = XShape[2]; int64_t input_size = XShape[2];
if (num_directions != wTy.getSizes()[0]) if (num_directions != wTy.getSizes()[0])
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "num_directions (" + std::to_string(num_directions) + binder.op, "num_directions (" + std::to_string(num_directions) +
") does not match the first dimension of wTy (" + ") does not match the first dimension of wTy (" +
std::to_string(wTy.getSizes()[0]) + ")"); std::to_string(wTy.getSizes()[0]) + ")");
if (num_directions != 1)
return rewriter.notifyMatchFailure(
binder.op, "num_directions (" + std::to_string(num_directions) +
") is not equal to 1");
if (4 * hidden_size != wTy.getSizes()[1]) if (4 * hidden_size != wTy.getSizes()[1])
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) + binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
@ -746,6 +805,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value R_forward = getDirection(b, 0, R); Value R_forward = getDirection(b, 0, R);
Value B_forward = getDirection(b, 0, B); Value B_forward = getDirection(b, 0, B);
Value W_reverse, R_reverse, B_reverse;
if (isBidirectional) {
W_reverse = getDirection(b, 1, W);
R_reverse = getDirection(b, 1, R);
B_reverse = getDirection(b, 1, B);
}
auto hTy = b.getType<ValueTensorType>( auto hTy = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size}, llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
xTy.getDtype()); xTy.getDtype());
@ -770,29 +836,44 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value initial_h; Value initial_h;
if (binder.tensorOperandAtIndex(initial_h, 5)) { if (binder.tensorOperandAtIndex(initial_h, 5)) {
// default created for layout 0
initial_h = initial_h =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1)
initial_h = StaticTranspose(b, initial_h, 0, 1);
} }
Value initial_c; Value initial_c;
if (binder.tensorOperandAtIndex(initial_c, 6)) { if (binder.tensorOperandAtIndex(initial_c, 6)) {
// default created for layout 0
initial_c = initial_c =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1)
initial_c = StaticTranspose(b, initial_c, 0, 1);
} }
// convert X from layout 1 to layout 0
if (layout == 1)
X = StaticTranspose(b, X, 0, 1);
// X, initial_h, initial_c are now in layout 0
Value initial_h_forward = getDirection(b, 0, initial_h); Value initial_h_forward = getDirection(b, 0, initial_h);
Value initial_c_forward = getDirection(b, 0, initial_c); Value initial_c_forward = getDirection(b, 0, initial_c);
if (num_directions != 1) { Value initial_h_reverse, initial_c_reverse;
return rewriter.notifyMatchFailure( if (isBidirectional) {
binder.op, "Unsupported num_directions. Only 1 is supported but " + initial_h_reverse = getDirection(b, 1, initial_h);
std::to_string(num_directions) + " is provided."); initial_c_reverse = getDirection(b, 1, initial_c);
// TODO: support bidirectional LSTM by doing both directions and replacing
// Unsqueeze with Stack
} }
// Everything hereon is for the forward direction, with the direction
// dimention squeezed out.
LstmWeights weights; // weights and biases // Everything hereon is for the forward direction (unless in bidirectional if
// block), with the direction dimention squeezed out and all inputs in layout
// 0 format
LstmWeights weights, weightsRev; // weights and biases
auto intConst = [&](int64_t val) { auto intConst = [&](int64_t val) {
return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val)); return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val));
@ -804,6 +885,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value recurrentWeightsEndIdx = intConst(8 * hidden_size); Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
auto biasType = b.getType<ValueTensorType>( auto biasType = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype()); llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
// forward
Value Wb = b.create<AtenSliceTensorOp>(biasType, Value Wb = b.create<AtenSliceTensorOp>(biasType,
/*input=*/B_forward, /*input=*/B_forward,
/*dim=*/cstZero, /*dim=*/cstZero,
@ -816,6 +898,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
/*start=*/recurrentWeightsStartIdx, /*start=*/recurrentWeightsStartIdx,
/*end=*/recurrentWeightsEndIdx, /*end=*/recurrentWeightsEndIdx,
/*step=*/cstOne); /*step=*/cstOne);
Value Wb_reverse, Rb_reverse;
if (isBidirectional) {
// reverse
Wb_reverse = b.create<AtenSliceTensorOp>(biasType,
/*input=*/B_reverse,
/*dim=*/cstZero,
/*start=*/cstZero,
/*end=*/inputWeightsEndIdx,
/*step=*/cstOne);
Rb_reverse = b.create<AtenSliceTensorOp>(biasType,
/*input=*/B_reverse,
/*dim=*/cstZero,
/*start=*/recurrentWeightsStartIdx,
/*end=*/recurrentWeightsEndIdx,
/*step=*/cstOne);
}
// gate splitting // gate splitting
auto gateBiasType = b.getType<ValueTensorType>( auto gateBiasType = b.getType<ValueTensorType>(
@ -833,61 +931,164 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
Value forgetGateWeightsEndIdx = intConst(3 * hidden_size); Value forgetGateWeightsEndIdx = intConst(3 * hidden_size);
Value cellGateWeightsEndIdx = intConst(4 * hidden_size); Value cellGateWeightsEndIdx = intConst(4 * hidden_size);
auto sliceIOFC = [&](std::function<Value(Value, Value)> slicerFunction) { auto sliceIOFC = [&](std::function<Value(Value, Value, Value)> slicerFunction,
Value WoB) {
// slice into 4 components and return tuple // slice into 4 components and return tuple
return std::make_tuple( return std::make_tuple(
slicerFunction(cstZero, inputGateWeightsEndIdx), slicerFunction(cstZero, inputGateWeightsEndIdx, WoB),
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx), slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB),
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx), slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB),
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx)); slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB));
}; };
auto sliceGateBias = [&](Value startIdx, Value endIdx) { auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateBiasType, Wb, cstZero, startIdx, return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
endIdx, cstOne); endIdx, cstOne);
}; };
std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) = std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) =
sliceIOFC(sliceGateBias); sliceIOFC(sliceGateBias, Wb);
auto sliceGateBiasR = [&](Value startIdx, Value endIdx) { if (isBidirectional)
return b.create<AtenSliceTensorOp>(gateBiasType, Rb, cstZero, startIdx, std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f,
weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse);
auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
endIdx, cstOne); endIdx, cstOne);
}; };
std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) = std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) =
sliceIOFC(sliceGateBiasR); sliceIOFC(sliceGateBiasR, Rb);
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) { if (isBidirectional)
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, W_forward, cstZero, std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f,
weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse);
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, WoB, cstZero,
startIdx, endIdx, cstOne); startIdx, endIdx, cstOne);
}; };
std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) = std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) =
sliceIOFC(sliceGateWeightsIH); sliceIOFC(sliceGateWeightsIH, W_forward);
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) { if (isBidirectional)
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, R_forward, cstZero, std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) =
sliceIOFC(sliceGateWeightsIH, W_reverse);
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) {
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, WoB, cstZero,
startIdx, endIdx, cstOne); startIdx, endIdx, cstOne);
}; };
std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) = std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) =
sliceIOFC(sliceGateWeightsHH); sliceIOFC(sliceGateWeightsHH, R_forward);
if (isBidirectional)
std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
sliceIOFC(sliceGateWeightsHH, R_reverse);
LstmLayerOutput lstmLayerOutput = lstm_layer( LstmLayerOutput lstmLayerOutput = lstm_layer(
b, X, initial_h_forward, initial_c_forward, weights, activations); b, X, initial_h_forward, initial_c_forward, weights, activations);
auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>( Value Y_h_result, Y_c_result, Y_result;
// if forward (unidirectional) unsqueeze and output
auto YallDtype =
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
auto Y_uni_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
YallDtype);
auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size}, llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype()); YallDtype);
Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>( auto Y_res_type = b.getType<ValueTensorType>(
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero); llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
Value Y_c_unsqueezed = b.create<AtenUnsqueezeOp>( hidden_size},
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero); YallDtype);
Value Y_h_forward =
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
Value Y_c_forward =
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero);
// unsqueeze num_directions dim1 of Y // unsqueeze num_directions dim1 of Y
// to create the onnx.LSTM output shape [seq_length, num_directions, // to create the onnx.LSTM output shape [seq_length, num_directions,
// batch_size, hidden_size] // batch_size, hidden_size]
Value Y_unsqueezed = Value Y_forward =
b.create<AtenUnsqueezeOp>(yTy, lstmLayerOutput.Y, cstOne); b.create<AtenUnsqueezeOp>(Y_uni_type, lstmLayerOutput.Y, cstOne);
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed, Y_result = Y_forward;
Y_c_unsqueezed}); Y_h_result = Y_h_forward;
Y_c_result = Y_c_forward;
// add bidrectional reverse layer
// this is just flip X, lstm layer, flip results, stack
// flip X
Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped,
Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list;
LstmLayerOutput revLstmLayerOutput;
if (isBidirectional) {
dim0 = b.create<PrimListConstructOp>(b.getType<ListType>(intType),
SmallVector<Value>{cstZero});
X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
revLstmLayerOutput =
lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
weightsRev, activationsRev);
// unsqueeze Y_rev, Y_h_rev, Y_c_rev
Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
revLstmLayerOutput.Y_h, cstZero);
Y_c_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
revLstmLayerOutput.Y_c, cstZero);
Y_reverse_unflipped =
b.create<AtenUnsqueezeOp>(Y_uni_type, revLstmLayerOutput.Y, cstOne);
// flip Y_rev on dim 0 [seq_len]
Y_reverse = b.create<AtenFlipOp>(Y_uni_type, Y_reverse_unflipped, dim0);
// Concat forward and reverse results on dim 1
Y_output_list =
b.create<PrimListConstructOp>(b.getType<ListType>(Y_uni_type),
SmallVector<Value>{Y_forward, Y_reverse});
Y_result = b.create<AtenCatOp>(Y_res_type, Y_output_list, cstOne);
// Concat forward and reverse results on dim 0
Y_h_output_list = b.create<PrimListConstructOp>(
b.getType<ListType>(Y_h_Y_c_uni_type),
SmallVector<Value>{Y_h_forward, Y_h_reverse});
Y_h_result =
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_h_output_list, cstZero);
Y_c_output_list = b.create<PrimListConstructOp>(
b.getType<ListType>(Y_h_Y_c_uni_type),
SmallVector<Value>{Y_c_forward, Y_c_reverse});
Y_c_result =
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_c_output_list, cstZero);
}
if (layout == 1) {
// Update Y, Y_h, Y_c results to layout 1
Y_result = StaticTranspose(b, Y_result, 1, 2);
Y_result = StaticTranspose(b, Y_result, 0, 1);
Y_h_result = StaticTranspose(b, Y_h_result, 0, 1);
Y_c_result = StaticTranspose(b, Y_c_result, 0, 1);
}
// Only add outputs specified in onnx output node
SmallVector<Value> actualOutputs = {Y_result, Y_h_result, Y_c_result},
outputs;
ValueTensorType resTy;
for (int i = 0; i < binder.getNumResults(); ++i) {
if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
outputs.push_back(cstNone);
} else {
outputs.push_back(actualOutputs[i]);
}
}
rewriter.replaceOp(binder.op, outputs);
return success(); return success();
} }
@ -1072,11 +1273,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstNone = b.create<ConstantNoneOp>(); Value cstNone = b.create<ConstantNoneOp>();
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0)); Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1)); Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));
// Binding arguments // Binding arguments
ValueTensorType yTy, Y_hType; ValueTensorType yTy, Y_hType;
if (binder.tensorResultTypeAtIndex(yTy, 0) || if (binder.tensorResultTypeAtIndex(yTy, 0) &&
binder.tensorResultTypeAtIndex(Y_hType, 1)) { binder.tensorResultTypeAtIndex(Y_hType, 1)) {
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,
"At least one output must be present"); "At least one output must be present");
@ -1132,6 +1332,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// Validations // Validations
auto XShape = xTy.getSizes(); auto XShape = xTy.getSizes();
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0]; int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
int64_t input_size = XShape[2]; int64_t input_size = XShape[2];
std::ostringstream oss; std::ostringstream oss;
@ -1173,6 +1374,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
initial_h = initial_h =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1) {
initial_h = StaticTranspose(b, initial_h, 0, 1);
}
} }
if (binder.tensorOperandAtIndex(sequence_lens, 4)) if (binder.tensorOperandAtIndex(sequence_lens, 4))
@ -1192,10 +1397,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// fill in B // fill in B
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
if (B == nullptr) { if (B == nullptr) {
SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size}; SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
SmallVector<Value> BShapeListContents = { SmallVector<Value> BShapeListContents = {
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)), b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))}; b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
Value BShapeList = b.create<PrimListConstructOp>( Value BShapeList = b.create<PrimListConstructOp>(
b.getType<ListType>(intType), BShapeListContents); b.getType<ListType>(intType), BShapeListContents);
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype()); auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
@ -1256,51 +1461,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
B_slices[4], B_slices[5]); B_slices[4], B_slices[5]);
// Process inputs based on layout // Process inputs based on layout
Value X_processed, initial_h_processed; if (layout == 1) {
ValueTensorType yTy_processed, Y_hType_processed; X = StaticTranspose(b, X, 0, 1);
if (layout == 0) {
X_processed = X;
initial_h_processed = initial_h_forward;
yTy_processed = yTy;
Y_hType_processed = Y_hType;
} else {
X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
initial_h_processed = b.create<AtenTransposeIntOp>(
initial_h.getType(), initial_h_forward, cstZero, cstOne);
auto yTySizes = yTy.getSizes();
auto Y_hTypeSizes = Y_hType.getSizes();
yTy_processed = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
yTySizes[3]},
yTy.getDtype());
Y_hType_processed = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
Y_hTypeSizes[2]},
Y_hType.getDtype());
} }
// Weights and biases ready. Calling GRU layer to insert the actual ops. // Weights and biases ready. Calling GRU layer to insert the actual ops.
GruLayerOutput gruLayerOutput = GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
gru_layer(b, X_processed, initial_h_processed, weights, activations, activations, linear_before_reset);
linear_before_reset);
// Process outputs based on layout // Process outputs based on layout
Value Y_final, Y_h_final; Value Y_final;
if (layout == 0) { if (binder.tensorResultTypeAtIndex(yTy, 0)) {
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne); Y_final = cstNone;
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
} else { } else {
auto Y_transposed = b.create<AtenTransposeIntOp>( if (layout == 0) {
gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne); Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo); } else {
Type yTy_original = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
yTy.getDtype());
Y_final =
b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
Y_final = StaticTranspose(b, Y_final, 1, 2);
Y_final = StaticTranspose(b, Y_final, 0, 1);
}
}
auto Y_h_transposed = b.create<AtenTransposeIntOp>( Value Y_h_final;
gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne); if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero); Y_h_final = cstNone;
} else {
if (layout == 0) {
Y_h_final =
b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
} else {
Type y_hTy_original = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
Y_hType.getDtype());
Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
cstZero);
Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
}
} }
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final}); rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});

View File

@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
template <typename OpTy, typename OpAdaptor> template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
int64_t &dim,
SmallVector<Value> &resultShape, SmallVector<Value> &resultShape,
SmallVector<Value> &offsets, SmallVector<Value> &offsets,
SmallVector<Value> &strides) { SmallVector<Value> &strides) {
@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1); Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant"); return op->emitError("unimplemented: dim is not constant");
@ -1658,10 +1658,17 @@ public:
if (!isValidDim(dim, inputRank)) if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid"); return rewriter.notifyMatchFailure(op, "dim is statically invalid");
// TODO: Handle the case where the dim(th) dimension is dynamic. // assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) { if (inputType.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure( Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
op, "unimplemented: dim(th) dimension is not expected to be dynamic"); Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
} }
const TypeConverter *typeConverter = getTypeConverter(); const TypeConverter *typeConverter = getTypeConverter();
@ -1671,7 +1678,7 @@ public:
// If the dim(th) dimension of operand tensor type is not statically unit, // If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation. // `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1) { if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success(); return success();
} }
@ -1857,14 +1864,46 @@ public:
RankedTensorType resultType = cast<RankedTensorType>( RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType())); typeConverter->convertType(op->getResult(0).getType()));
SmallVector<Value> resultShape; SmallVector<Value> resultShape, offsets, strides;
SmallVector<Value> offsets; int64_t dim;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp, if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
AtenSliceTensorOpAdaptor>( AtenSliceTensorOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) { op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure(); return failure();
} }
// If stride is negative, then flip the input tensor corresponding to that
// dim, update the stride for flipped tensor by multiplying it by -1, and
// update the offset as follows:
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
//
// For example:
// Input = [0, 1, 2, 3, 4, 5]
// stride = [-2], result_shape = [2], offset = [3]
// Result = [3, 1]
// After flipping:
// Input = [5, 4, 3, 2, 1, 0]
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
// Result = [3, 1]
Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
SmallVector<int64_t>{dim});
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, strides[dim], zero);
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
Value resShapeMulStride =
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
Value flippedOffset =
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
offsets[dim] = rewriter.create<arith::SelectOp>(
loc, isNegativeStride, flippedOffset, offsets[dim]);
input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
flippedInput, input);
SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic); SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get( auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding()); dynShape, resultType.getElementType(), resultType.getEncoding());
@ -2095,12 +2134,11 @@ public:
RankedTensorType resultType = cast<RankedTensorType>( RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType())); typeConverter->convertType(op->getResult(0).getType()));
SmallVector<Value> resultShape; SmallVector<Value> resultShape, offsets, strides;
SmallVector<Value> offsets; int64_t dim;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp, if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>( AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) { op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
return failure(); return failure();
} }
@ -2573,6 +2611,167 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenUnfoldOp : public OpConversionPattern<AtenUnfoldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto self = adaptor.getSelf();
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
int64_t dimension;
if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) {
return rewriter.notifyMatchFailure(op,
"only support constant int dimension");
}
int64_t size;
if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) {
return rewriter.notifyMatchFailure(op, "only support constant int size");
}
int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
return rewriter.notifyMatchFailure(op, "only support constant int step");
}
if (step <= 0) {
return rewriter.notifyMatchFailure(op, "step must be greater than zero.");
}
int64_t selfRank = selfType.getRank();
// Zero-Rank case
if (selfRank == 0) {
// Empty tensor
if (size == 0) {
RankedTensorType resultType =
RankedTensorType::get({0}, selfType.getElementType());
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultType.getShape(), resultType.getElementType());
rewriter.replaceOp(op, emptyTensor);
return success();
}
Value unsqueezedSelf = rewriter.create<tensor::ExpandShapeOp>(
loc, RankedTensorType::get({1}, selfType.getElementType()), self,
ArrayRef<ReassociationIndices>{});
rewriter.replaceOp(op, unsqueezedSelf);
return success();
}
auto shape = selfType.getShape();
if (dimension < 0) {
dimension = toPositiveDim(dimension, selfRank);
}
if (!isValidDim(dimension, selfRank)) {
return rewriter.notifyMatchFailure(op, "dimension out of range");
}
Value dimSize = rewriter.create<tensor::DimOp>(loc, self, dimension);
Value sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
Value sizeCheck = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ule, sizeValue, dimSize);
rewriter.create<cf::AssertOp>(
loc, sizeCheck,
rewriter.getStringAttr("size must be <= target dimension"));
/* Calculate output shape of unfold op:
* https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
* outputShape[dimension] is set to numBlocks, with size appended as an
* additional dimension
*/
SmallVector<OpFoldResult> outputShape;
for (int64_t i = 0; i < selfRank; i++) {
if (i == dimension) {
outputShape.push_back(getDynamicOrStaticNumBlocks(
rewriter, loc, shape[dimension], dimSize, size, step));
} else if (shape[i] == ShapedType::kDynamic) {
outputShape.push_back(
OpFoldResult(rewriter.create<tensor::DimOp>(loc, self, i)));
} else {
outputShape.push_back(rewriter.getIndexAttr(shape[i]));
}
}
outputShape.push_back(rewriter.getIndexAttr(size));
// Empty tensor to insert values into
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, outputShape, selfType.getElementType());
/**
* Use reindexing to map output indices to input indices
* i.e. In output of rank 3 case:
* (i, j, k) => (i', j') where i' = i * step + k and j' = j
* if dimension == 0
* (i, j, k) => (i', j') where i' = i and j' = j * step + k
* if dimension == 1
*/
MLIRContext *context = rewriter.getContext();
SmallVector<AffineExpr> outputExprs;
for (int dim = 0; dim < selfRank; ++dim) {
if (dim == dimension) {
auto idxLast = getAffineDimExpr(selfRank, context);
auto idxDimension = getAffineDimExpr(dimension, context);
AffineExpr dimIdx =
idxLast + idxDimension * rewriter.getAffineConstantExpr(step);
outputExprs.push_back(dimIdx);
} else {
outputExprs.push_back(getAffineDimExpr(dim, context));
}
}
int64_t outputRank = selfRank + 1;
auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context);
auto outputAffineMap =
AffineMap::getMultiDimIdentityMap(outputRank, context);
SmallVector<utils::IteratorType> iteratorTypes(
outputRank, utils::IteratorType::parallel);
Value result =
rewriter
.create<linalg::GenericOp>(
loc, outputTensor.getType(), self, outputTensor,
ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes,
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
b.create<linalg::YieldOp>(nestedLoc, args[0]);
})
.getResult(0);
rewriter.replaceOp(op, result);
return success();
}
private:
OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc,
int64_t shapeDim, Value dimSize,
int64_t size, int64_t step) const {
/**
* numBlocks = (shape[dimension] - size) // step + 1
*/
if (shapeDim == ShapedType::kDynamic) {
Value numBlocksSubOp = rewriter.create<arith::SubIOp>(
loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, size));
Value numBlocksDivOp = rewriter.create<arith::DivUIOp>(
loc, numBlocksSubOp,
rewriter.create<arith::ConstantIndexOp>(loc, step));
Value numBlocks = rewriter.create<arith::AddIOp>(
loc, rewriter.create<arith::ConstantIndexOp>(loc, 1), numBlocksDivOp);
return OpFoldResult(numBlocks);
}
int64_t staticNumBlocks = (shapeDim - size) / step + 1;
return rewriter.getIndexAttr(staticNumBlocks); // Use static value
}
};
} // namespace
namespace { namespace {
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> { class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
public: public:
@ -2641,7 +2840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
/*benefit=*/200); /*benefit=*/200);
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context, patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
/*benefit=*/100); /*benefit=*/100);
target.addIllegalOp<AtenUnfoldOp>();
patterns.add<ConvertAtenUnfoldOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>(); target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context); patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeDimOp>(); target.addIllegalOp<AtenSqueezeDimOp>();

View File

@ -301,14 +301,9 @@ public:
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc(); Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfRank = auto selfRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
Type elementType =
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
SmallVector<int64_t> axis; SmallVector<int64_t> axis;
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
@ -321,40 +316,8 @@ public:
} }
} }
// Only used to calculate flipped values, i.e. those on the flip axes. Other Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);
SmallVector<utils::IteratorType> iteratorTypes(
selfRank, utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, self.getType(), self, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(
loc, dims[flipDim], indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
return success(); return success();
} }
}; };
@ -1300,10 +1263,6 @@ public:
return success(); return success();
} }
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
// Special depthwise case: Cin = Cout = groups. // Special depthwise case: Cin = Cout = groups.
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
// of groups) to be depthwise in their documentation, but the linalg ops // of groups) to be depthwise in their documentation, but the linalg ops
@ -1315,21 +1274,45 @@ public:
if (inShape[1] == numGroups && weightShape[0] == numGroups && if (inShape[1] == numGroups && weightShape[0] == numGroups &&
weightShape[1] == 1) { weightShape[1] == 1) {
// Collapse weight shape (C/G == 1) // Collapse weight shape (C/G == 1)
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}}; SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1], SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
weightShape[2], weightShape[3]}; for (unsigned i = 0; i < numSpatialDims; i++) {
collapsedDims.push_back({i + 2});
collapsedShape.push_back(weightShape[i + 2]);
}
Type collapsedType = RankedTensorType::get( Type collapsedType = RankedTensorType::get(
makeShapeLLVMCompatible(collapsedShape), weightDTy); makeShapeLLVMCompatible(collapsedShape), weightDTy);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>( Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims); loc, collapsedType, weight, collapsedDims);
if (!inputZp) { if (!inputZp) {
conv = rewriter switch (numSpatialDims) {
.create<linalg::DepthwiseConv2DNchwChwOp>( case 1:
loc, outputTensor.getType(), conv = rewriter
ValueRange{paddedInput, collapsedWeight}, outputTensor, .create<linalg::DepthwiseConv1DNcwCwOp>(
stridesAttr, dilationAttr) loc, outputTensor.getType(),
.getResult(0); ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 2:
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D depthwise convolution "
"supported for special case of group convolution");
};
} else { } else {
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D depthwise quantized convolution "
"supported for special case of group convolution");
// currently, the only named depthwise qconv op is nhwc_hwc // currently, the only named depthwise qconv op is nhwc_hwc
// input: nchw -> nhwc; weight (collapsed): chw -> hwc // input: nchw -> nhwc; weight (collapsed): chw -> hwc
// linalg conv result nhwc -> nchw // linalg conv result nhwc -> nchw
@ -1376,6 +1359,10 @@ public:
return success(); return success();
} }
if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
// Grouped case, use the grouped conv linalg op // Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) { auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = cast<RankedTensorType>(tensor.getType()); auto inType = cast<RankedTensorType>(tensor.getType());

View File

@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0)); b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero); return createEqual(b, loc, floatDtype, self, zero);
} }
if (auto complex = dyn_cast<AtenComplexOp>(op)) {
auto ctype = cast<ComplexType>(
cast<RankedTensorType>(converter->convertType(complex.getType()))
.getElementType());
Type stype = ctype.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype);
return b.create<complex::CreateOp>(loc, ctype, lhs, rhs);
}
if (isa<AtenAbsOp>(op)) { if (isa<AtenAbsOp>(op)) {
if (isa<IntegerType>(payloadArgs[0].getType())) if (isa<IntegerType>(payloadArgs[0].getType()))
return b.create<math::AbsIOp>(loc, payloadArgs[0]); return b.create<math::AbsIOp>(loc, payloadArgs[0]);
@ -1590,22 +1600,22 @@ public:
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>(op)) AtenQuantizePerTensorOp, AtenIscloseOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,

View File

@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
.getResult(0); .getResult(0);
return success(); return success();
} }
// Flips an input tensor based on the values of axis list.
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> axis) {
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();
// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
for (auto flipDim : axis)
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);
SmallVector<utils::IteratorType> iteratorTypes(selfRank,
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
Value flipped =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(), input, initTensor, indexingMaps,
iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (auto i = 0; i < selfRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (auto flipDim : axis) {
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
indices[flipDim]);
}
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
return flipped;
}

View File

@ -325,7 +325,8 @@ public:
lhsContractingDim, rhsContractingDim); lhsContractingDim, rhsContractingDim);
output = rewriter output = rewriter
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs, .create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr) dotDimensionNumbers, nullptr,
nullptr)
.getResult(); .getResult();
return success(); return success();
} }
@ -494,7 +495,7 @@ public:
/*lhsContractingDimensions=*/{lhsContractingDim}, /*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim}); /*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>( Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr);
Value matmulPlusBias = matmulOutput; Value matmulPlusBias = matmulOutput;
if (!isa<Torch::NoneType>(biasTy)) { if (!isa<Torch::NoneType>(biasTy)) {

View File

@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
// Max pooling // Max pooling
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp, if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
AtenMaxPool2dWithIndicesOp>(op)) { AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr; return nullptr;
} }
// AtenMaxPool1dWithIndicesOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto outValTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
auto outIdxTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
if (inputRank <= 1) {
return op.emitError(
"max_pooling1d only supports inputs with rank higher than 1");
}
SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
}
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 1);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 1);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 1);
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
DenseI64ArrayAttr baseDilations;
auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
// no need to reshape here for max_pool_1d. Need to make sure the iota
// dimension. dim=inputRank-2 or dim=inputRank-1?
auto indexTensor =
rewriter
.create<stablehlo::DynamicIotaOp>(
op->getLoc(),
RankedTensorType::get(inputShape, rewriter.getI64Type()),
inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
.getResult();
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
// add block.
Block &block = reduceWindowOp.getBody().emplaceBlock();
auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc());
auto *firstValArg = block.args_begin();
auto *firstIdxArg = std::next(firstValArg);
auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// Get smaller index if compared values are equal.
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr);
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
// AtenMaxPool2dWithIndicesOp // AtenMaxPool2dWithIndicesOp
template <> template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \ #define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options) patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp); INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp); INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
#undef INSERT_ATEN_POOLING_PATTERN #undef INSERT_ATEN_POOLING_PATTERN

View File

@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
} }
} }
if (isa<AtenAllOp>(op)) { if (isa<AtenAllOp, AtenAllDimOp>(op)) {
auto constAttr = auto constAttr =
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)}); DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
AtenLinalgVectorNormOp>(op)) { AtenLinalgVectorNormOp>(op)) {
result = rewriter.create<stablehlo::AddOp>( result = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAllOp>(op)) { } else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
result = rewriter.create<stablehlo::AndOp>( result = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) { } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \ patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
options) options)
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp); INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN #undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \ #define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \

View File

@ -161,12 +161,70 @@ public:
using ConvertAtenOp<AtenOpT>::ConvertAtenOp; using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor; using OpAdaptor = typename AtenOpT::Adaptor;
unsigned getBitWidth(Type type) const {
if (auto complexTy = dyn_cast<ComplexType>(type))
return 2 * getBitWidth(complexTy.getElementType());
return type.getIntOrFloatBitWidth();
}
LogicalResult LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType()); auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
if (!rankType) if (!rankType)
return op.emitError("Only ranked tensor types are currently supported"); return op.emitError("Only ranked tensor types are currently supported.");
auto loc = op.getLoc();
// support AtenViewDtypeOp
if (isa<AtenViewDtypeOp>(op)) {
auto self = adaptor.getSelf();
auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
// infer the result shape
auto operandElt = rankType.getElementType();
auto targetElt = baseResultTy.getDtype();
auto operandEltBitWidth = getBitWidth(operandElt);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandSizes = rankType.getShape();
SmallVector<int64_t> castShape(operandSizes);
if (operandEltBitWidth > targetEltBitWidth) {
int64_t last_size = operandEltBitWidth / targetEltBitWidth;
castShape.push_back(last_size);
} else if (operandEltBitWidth < targetEltBitWidth) {
int64_t last_size = targetEltBitWidth / operandEltBitWidth;
if (!ShapedType::isDynamic(castShape.back()) and
last_size != castShape.back()) {
return rewriter.notifyMatchFailure(
op, "The last dim size is not equal to targetEltBitWidth / "
"operandEltBitWidth.");
} else {
castShape.pop_back();
}
}
auto resultType =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
baseResultTy);
if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "Currently only support static output shape.");
}
auto castType =
baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
loc,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
castType),
self);
auto reshape =
rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
rewriter.replaceOp(op, reshape);
return success();
}
// collect Value of dims // collect Value of dims
SmallVector<Value, 4> dimSizes; SmallVector<Value, 4> dimSizes;
@ -174,7 +232,6 @@ public:
return op.emitError("Dims size must be a list of Scalar"); return op.emitError("Dims size must be a list of Scalar");
} }
auto loc = op.getLoc();
if (dimSizes.size() == 0 || rankType.getRank() == 0) { if (dimSizes.size() == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, op,
@ -236,6 +293,13 @@ public:
SmallVector<Value, 4> &dimSizes) const; SmallVector<Value, 4> &dimSizes) const;
}; };
template <>
bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &dimSizes) const {
return false;
}
template <> template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes( bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
#define INSERT_VIEW_OP_PATTERN(AtenOp) \ #define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options) patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN #undef INSERT_VIEW_OP_PATTERN

View File

@ -1497,6 +1497,79 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type elementType = resultType.getElementType();
Type inputElementType =
cast<RankedTensorType>(input.getType()).getElementType();
// Converting the input element type to the result's element type.
// The only possible mismatch would be when the input element type is an
// integer but not `si64`. Therefore, we directly convert the input to
// `si64`. Rest all cases are handled in the dtype definition for this op.
if (elementType != inputElementType) {
Value torchInput = convertTensorToDtype(
rewriter, loc, op.getSelf(),
rewriter.getIntegerType(64, IntegerType::Signed));
input = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(torchInput.getType()),
torchInput);
}
int64_t inputRank = resultType.getRank();
Value dtype = op.getDtype();
if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure(
op, "unsupported: dtype argument not supported");
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant dim value is supported");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "invalid dim");
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
output = rewriter.create<tensor::CastOp>(loc, resultType, output);
SmallVector<Value> accSizes(sizes);
accSizes.erase(accSizes.begin() + dim);
SmallVector<int64_t> accStatic(
makeShapeTorchCompatible(resultType.getShape()));
accStatic.erase(accStatic.begin() + dim);
Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
Type accType =
RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
acc = rewriter.create<tensor::CastOp>(loc, accType, acc);
Value result = createTMTensorScanOp(
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
[](OpBuilder &b, Location loc, Value input, Value acc) {
Value prod =
(isa<mlir::FloatType>(input.getType())
? b.create<arith::MulFOp>(loc, input, acc)->getResult(0)
: b.create<arith::MulIOp>(loc, input, acc)->getResult(0));
b.create<TMTensor::YieldOp>(loc, prod);
});
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
namespace { namespace {
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> { class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public: public:
@ -2240,6 +2313,8 @@ public:
patterns.add<ConvertAtenSortOp>(typeConverter, context); patterns.add<ConvertAtenSortOp>(typeConverter, context);
target.addIllegalOp<AtenCumsumOp>(); target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenCumsumOp>(typeConverter, context); patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
target.addIllegalOp<AtenCumprodOp>();
patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
target.addIllegalOp<AtenScaledDotProductAttentionOp>(); target.addIllegalOp<AtenScaledDotProductAttentionOp>();
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter, patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
context); context);

View File

@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant"); "Unable to extract the scalar constant");
int64_t numElem = 1;
for (int64_t dim : dshape)
numElem *= dim;
if (isa<mlir::FloatType>(dtype)) { if (isa<mlir::FloatType>(dtype)) {
tosaTensor = tosa::getConstTensor<float>(rewriter, op, tosaTensor =
(isFloat ? doubleValue : intValue), tosa::getConstTensor<float>(
dshape, dtype) rewriter, op,
.value(); SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
dshape, dtype)
.value();
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) { } else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
auto w = intType.getWidth(); auto w = intType.getWidth();
if (w != 1 && w != 32 && w != 64) if (w != 1 && w != 32 && w != 64)
@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
} }
bool d = isFloat ? static_cast<bool>(doubleValue) bool d = isFloat ? static_cast<bool>(doubleValue)
: static_cast<bool>(intValue); : static_cast<bool>(intValue);
tosaTensor = tosaTensor = tosa::getConstTensor<bool>(
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value(); rewriter, op, SmallVector<bool>(numElem, d), dshape)
.value();
} else if (w == 32) { } else if (w == 32) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) { if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
} }
int32_t d = isFloat ? static_cast<int32_t>(doubleValue) int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue); : static_cast<int32_t>(intValue);
tosaTensor = tosaTensor = tosa::getConstTensor<int32_t>(
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value(); rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
.value();
} else if (w == 64) { } else if (w == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) { if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
"of destination type"); "of destination type");
} }
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue); int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
tosaTensor = tosaTensor = tosa::getConstTensor<int64_t>(
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value(); rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
.value();
} }
} else { } else {
return rewriter.notifyMatchFailure(op, "Usupported element type"); return rewriter.notifyMatchFailure(op, "Usupported element type");
@ -891,8 +900,6 @@ public:
if (!result) if (!result)
return failure(); return failure();
// TBD - support dtype casting.
rewriter.replaceOp(op, {result.value()}); rewriter.replaceOp(op, {result.value()});
return success(); return success();
@ -2842,8 +2849,12 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Not all dims are valid"); return rewriter.notifyMatchFailure(op, "Not all dims are valid");
} }
auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>( SmallVector<int32_t> dimListInt32;
rewriter, op.getOperation(), dimListInt, {selfRank}); for (auto v : dimListInt)
dimListInt32.push_back(v);
auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
rewriter, op.getOperation(), dimListInt32, {selfRank});
rewriter.replaceOpWithNewOp<tosa::TransposeOp>( rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
@ -3819,6 +3830,124 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto input = adaptor.getSelf();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType inputs are currently supported");
auto index = adaptor.getIndex();
auto indexType = dyn_cast<RankedTensorType>(index.getType());
if (!indexType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType indices are currently supported");
auto inputShape = inputType.getShape();
int inputRank = inputType.getRank();
if (indexType.getRank() == 0)
return rewriter.notifyMatchFailure(
op, "Rank 0 index tensor is currently not supported");
// Dynamic shape check
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "AtenIndexSelectOp: support for dynamic input "
"shape not implemented");
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexType.getShape(),
rewriter.getIntegerType(32)),
index);
}
// Get positive dim
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "Value `dim` should be a torch constant int");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "Value `dim` is invalid");
// Get the output type
auto outType = getTypeConverter()->convertType(op.getType());
// Reshape and expand the index tensor to have same rank and same dimensions
// (except for the targeted dim) as the input
//
// For example:
// Input shape = (4, 5, 6)
// Index vector shape = (2)
// Targeted dim = 1
// Reshaped and expanded index vector shape = (4, 2, 6)
//
// By reshaping and expanding the index vector, we can supply it into the
// gather op to mimic the functionality of aten.index_select
SmallVector<int64_t> indicesInputRankShape;
for (int64_t i = 0; i < inputRank; i++) {
if (i == dim) {
indicesInputRankShape.push_back(indexType.getShape()[0]);
} else {
indicesInputRankShape.push_back(1);
}
}
auto indicesInputRankType =
RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape),
rewriter.getIntegerType(32));
auto reshapedIndices = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesInputRankType, index,
rewriter.getDenseI64ArrayAttr(indicesInputRankShape));
SmallVector<int64_t> tileShape(indicesInputRankShape);
SmallVector<int64_t> expandedIndicesShape(indicesInputRankShape);
for (int64_t i = 0; i < inputRank; i++) {
if (tileShape[i] == 1 && i != dim) {
tileShape[i] = inputShape[i];
expandedIndicesShape[i] = inputShape[i];
} else {
tileShape[i] = 1;
}
}
auto tileType =
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
rewriter.getIntegerType(32));
auto expandedIndices = rewriter.create<tosa::TileOp>(
op->getLoc(), tileType, reshapedIndices.getResult(),
rewriter.getDenseI64ArrayAttr(tileShape));
// convert torch style index and dim into tf style indices
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
auto indicesTf = tosa::convertTorchIndexToTfIndices(
rewriter, op, input, expandedIndices.getResult(), dim);
if (!indicesTf)
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices failed");
// do the tf gathernd algorithm with tf style indices as input.
auto result =
tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value());
if (!result) {
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
}
rewriter.replaceOp(op, {result.value()});
return success();
}
template <> template <>
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
@ -5200,7 +5329,7 @@ public:
}; };
template <typename AtenOpT> template <typename AtenOpT>
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> { class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor; using OpAdaptor = typename AtenOpT::Adaptor;
@ -5216,18 +5345,48 @@ public:
op, "Only Tensor types with static shapes are currently supported"); op, "Only Tensor types with static shapes are currently supported");
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) { if (!outElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported"); op, "Only floating-point or integer datatype legalization supported");
}
Value constOp;
if (failed(torchScalarToTosaTensor(
rewriter, op, op.getValue(), constOp, outElemTy,
makeShapeTorchCompatible(outType.getShape()))))
return rewriter.notifyMatchFailure(
op, "Supplied value must be a Scalar constant");
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp); Value fillValueTargetTensor;
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
// Reshape value tensor to have same rank and shape as input
auto inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
auto fillValue = adaptor.getValue();
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
if (!fillValueType)
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
auto fillValueElemTy = fillValueType.getElementType();
SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
auto fillValueMatchedInputRankType = RankedTensorType::get(
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
fillValueElemTy);
auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), fillValueMatchedInputRankType, fillValue,
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
fillValueElemTy),
fillValueMatchedInputRankTensor.getResult(),
makeShapeTorchCompatible(outType.getShape()));
} else {
if (failed(torchScalarToTosaTensor(
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
makeShapeTorchCompatible(outType.getShape()))))
return rewriter.notifyMatchFailure(
op, "Fill value must be a scalar constant");
}
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
fillValueTargetTensor);
return success(); return success();
} }
@ -5647,8 +5806,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
return success(); return success();
} }
// Template to create support tril mask tensor for aten.tril // Template to create supporting tril mask tensor for aten.tril
// legalization
template <typename T> template <typename T>
Value createTrilMask(PatternRewriter &rewriter, Operation *op, Value createTrilMask(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> shape, int64_t h, int64_t w, ArrayRef<int64_t> shape, int64_t h, int64_t w,
@ -5671,28 +5829,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value(); return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
} }
// Function to get tril mask tensor based on input type
// for aten.tril legalization
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> shape, int64_t h, int64_t w,
int64_t diagonal, Type type) {
return TypeSwitch<Type, Value>(type)
.Case<mlir::FloatType>([&](auto) {
return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
case 32:
return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
case 64:
return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
}
llvm_unreachable("Invalid integer width");
});
}
// Legalization for aten.tril // Legalization for aten.tril
template <> template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
@ -5740,14 +5876,31 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer"); return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
// Define shape for mask tensor based on rank // Define shape for mask tensor based on rank
SmallVector<int64_t> constShape; SmallVector<int64_t> maskShape;
for (auto i = 0; i < selfRank - 2; i++) for (auto i = 0; i < selfRank - 2; i++)
constShape.push_back(1); maskShape.push_back(1);
constShape.push_back(h); maskShape.push_back(h);
constShape.push_back(w); maskShape.push_back(w);
Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal, Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
resultType.getElementType()); .Case<mlir::FloatType>([&](auto) {
return createTrilMask<float>(rewriter, op, maskShape,
h, w, diagonal);
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return createTrilMask<bool>(rewriter, op, maskShape,
h, w, diagonal);
case 32:
return createTrilMask<int32_t>(
rewriter, op, maskShape, h, w, diagonal);
case 64:
return createTrilMask<int64_t>(
rewriter, op, maskShape, h, w, diagonal);
}
llvm_unreachable("Invalid integer width");
});
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask, rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
/*shift=*/0); /*shift=*/0);
@ -5755,6 +5908,311 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
return success(); return success();
} }
// Legalization for aten.flip
template <>
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
AtenFlipOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types are currently supported");
SmallVector<int64_t> dims;
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
return rewriter.notifyMatchFailure(
op, "Only constant dims are currently supported");
auto selfRank = selfTy.getRank();
auto resultTy = getTypeConverter()->convertType(op.getType());
Value result = self;
for (auto &dim : dims) {
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
static_cast<int32_t>(dim));
}
rewriter.replaceOp(op, result);
return success();
}
// Legalization for aten.round:
// Rounds elements of input to the nearest integer.
// Implements "round half to even" to break ties when a number is equidistant
// from two integers.
template <>
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
AtenRoundOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// To round to the nearest integer, we will consider the fractional part of
// the input element (= input element - integer part of element). If the
// fractional part is smaller than 0.5, round the number down. If the
// fractional part is 0.5, apply "round half to even" rule. If the fractional
// part is greater than 0.5, round up.
//
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
// res = floor(input)
// else:
// res = ceil(input)
auto self = adaptor.getSelf();
auto selfTy = dyn_cast<TensorType>(self.getType());
if (!selfTy)
return rewriter.notifyMatchFailure(op, "Only tensor types supported");
auto resultTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto boolTy =
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
auto resultElemTy = resultTy.getElementType();
auto oneHalf =
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
auto two =
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
auto floorInput =
rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
// input - floor(input)
auto fractionalPart = rewriter.create<tosa::SubOp>(
op->getLoc(), resultTy, self, floorInput.getResult());
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
auto floorDivResult = rewriter.create<tosa::FloorOp>(
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
// (floor(input) // 2) * 2
auto evenComparison = rewriter.create<tosa::MulOp>(
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
auto floorInputEven = rewriter.create<tosa::EqualOp>(
op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
// (frac == 0.5) && (floor(input) % 2 == 0)
auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
floorInputEven.getResult());
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
op->getLoc(), boolTy, fracLtOneHalf.getResult(),
fracEqualOneHalfCond.getResult());
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
ceilInput.getResult());
return success();
}
// Template to create supporting diagonal mask tensor for aten.diagonal
template <typename T>
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> shape, int64_t h, int64_t w,
int64_t offset) {
SmallVector<T> vec;
for (int64_t i = 0; i < h; i++) {
for (int64_t j = 0; j < w; j++) {
// Positive offset value moves above the main diagonal, while negative
// diagonal value moves below the main diagonal.
if (i + offset == j) {
vec.push_back(static_cast<T>(1));
} else {
vec.push_back(static_cast<T>(0));
}
}
}
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}
// Legalization for aten.diagonal
template <>
LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
AtenDiagonalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a ranked tensor type
auto selfType = dyn_cast<RankedTensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types are supported");
// Rank below 2 not accepted
auto selfRank = selfType.getRank();
if (selfRank <= 1)
return rewriter.notifyMatchFailure(
op, "Rank 0 and 1 are not accepted as they cause underflow");
if (!selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Currently only static shapes are supported");
const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
if (!resultType)
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
auto selfElemTy = selfType.getElementType();
auto resultElemTy = resultType.getElementType();
int64_t offset, dim1, dim2;
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
offset = 0;
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
dim1 = 0;
} else {
dim1 = toPositiveDim(dim1, selfRank);
}
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
dim2 = 1;
} else {
dim2 = toPositiveDim(dim2, selfRank);
}
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
int64_t h = selfShape[dim1];
int64_t w = selfShape[dim2];
// Overflowing offset not supported
if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
return rewriter.notifyMatchFailure(
op, "Offset greater or equal than shape not supported");
int64_t targetDim1 = selfRank - 2;
int64_t targetDim2 = selfRank - 1;
Value selfTransposed = self;
SmallVector<int64_t> transposedInputShape = selfShape;
RankedTensorType transposedInputType = selfType;
// If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
// so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
// we can consistently create the diagonal mask tensor.
if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
SmallVector<int32_t> transposedDims;
transposedInputShape.clear();
for (int64_t i = 0; i < selfRank; ++i) {
if (i == dim1 || i == dim2)
continue;
transposedDims.push_back(i);
}
transposedDims.push_back(dim1);
transposedDims.push_back(dim2);
auto transposedDimsConst = tosa::getConstTensor<int32_t>(
rewriter, op,
/*vec=*/transposedDims,
/*shape=*/{static_cast<int32_t>(selfRank)});
for (auto &dim : transposedDims)
transposedInputShape.push_back(selfShape[dim]);
transposedInputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
selfTransposed = rewriter.create<tosa::TransposeOp>(
op->getLoc(), transposedInputType, self, transposedDimsConst.value());
}
// Define shape for mask tensor based on rank
SmallVector<int64_t> maskShape;
for (auto i = 0; i < selfRank - 2; i++)
maskShape.push_back(1);
maskShape.push_back(h);
maskShape.push_back(w);
Value diagonalMask =
TypeSwitch<Type, Value>(resultElemTy)
.Case<mlir::FloatType>([&](auto) {
return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
offset);
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
offset);
case 32:
return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
offset);
case 64:
return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
offset);
}
llvm_unreachable("Invalid integer width");
});
Value diagonalTensor = rewriter.create<tosa::MulOp>(
op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
/*shift=*/0);
auto resultShape = makeShapeTorchCompatible(resultType.getShape());
auto targetReduceDim = resultShape[resultType.getRank() - 1];
// If transposedInputShape[targetDim1] (or h) is greater than the innermost
// dim of the result, we won't get the correct shape when we reduce sum along
// the innermost dim to get the result. Therefore, we have to slice the
// transposed tensor so that transposedInputShape[targetDim1] ==
// targetReduceDim.
if (h > targetReduceDim) {
transposedInputShape[targetDim1] = targetReduceDim;
transposedInputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
SmallVector<int64_t> startSlice(selfRank, 0);
SmallVector<int64_t> sizeSlice =
llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
if (offset < 0)
startSlice[targetDim1] = std::abs(offset);
diagonalTensor = rewriter.create<tosa::SliceOp>(
op->getLoc(), transposedInputType, diagonalTensor,
rewriter.getDenseI64ArrayAttr(startSlice),
rewriter.getDenseI64ArrayAttr(sizeSlice));
}
// Apply Reduce Sum to get the result
auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
auto reduceDimAttr =
DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
auto result =
mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
reduceDimAttr, /*keep_dims=*/false);
rewriter.replaceOp(op, result.value());
return success();
}
} // namespace } // namespace
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -5986,11 +6444,13 @@ public:
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN #undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \ #define INSERT_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp); INSERT_FILL_PATTERN(AtenFill_ScalarOp);
#undef INSERT_FILL_SCALAR_PATTERN INSERT_FILL_PATTERN(AtenFillScalarOp);
INSERT_FILL_PATTERN(AtenFillTensorOp);
#undef INSERT_FILL_PATTERN
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \ #define INSERT_MASKED_FILL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
@ -6060,6 +6520,10 @@ public:
INSERT_ATENOP_PATTERN(AtenIscloseOp); INSERT_ATENOP_PATTERN(AtenIscloseOp);
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
INSERT_ATENOP_PATTERN(AtenTrilOp); INSERT_ATENOP_PATTERN(AtenTrilOp);
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRoundOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -23,6 +23,15 @@ namespace tosa {
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
// This function is a helper for `convertTorchIndexToTfIndices`.
//
// We convert PyTorch index to TensorFlow-style indices so that we can use
// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather
// and Scatter operators to TOSA using TensorFlow-style indices.
// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow
// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want
// to gather/scatter elements, while in TensorFlow, the indices point directly
// to positions that you want to gather/scatter elements.
std::optional<Value> std::optional<Value>
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t> indicesOneDimShape, int32_t dim, SmallVector<int64_t> indicesOneDimShape, int32_t dim,
@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
unsigned indexRank = indexShape.size(); unsigned indexRank = indexShape.size();
SmallVector<int32_t> indicesVec; // input vec to create tosaConstant SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid)
// Create torch.meshgrid inputs // Create torch.meshgrid inputs
// Example: indexShape=[1,4,2] // Example: indexShape=[1,4,2]
// dim0: indicesMetaElement = torch.arange(0, 1) = [0] // dim0: indicesMetaElement = torch.arange(0, 1) = [0]
// dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3] // dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
// dim2: indicesMetaElement = torch.arange(0, 2) = [0,1] // dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
for (int i = 0; i < indexShape[dim]; i++) { for (int i = 0; i < indexShape[dim]; i++)
indicesMetaElement.push_back(i); indicesMetaElement.push_back(i);
}
// Compute total number of meta element repeat times: int preDimMetaElementRepeatTimes = 1;
// = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim int postDimMetaElementRepeatTimes = 1;
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
for (int i = 0; i < static_cast<int>(indexRank); i++) {
if (i == dim) {
continue;
} else {
indicesMetaElementRepeatTimes *= indexShape[i];
}
}
if (dim != static_cast<int>(indexShape.size()) - 1) { // Compute total number of times meta element range should repeat
// Create one dim indices for index except for last dim // = product(indexShape[0:dim])
// Create indices raw vector. // dim0: preDimMetaElementRepeatTimes = 1
// torch.stack(torch.meshgrid) // dim1: preDimMetaElementRepeatTimes = 1
// dim0: indicesVec = [0 0 0 0 0 0 0 0] // dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4
// dim0: indicesVec = [0 0 1 1 2 2 3 3] for (int i = 0; i < dim; i++)
preDimMetaElementRepeatTimes *= indexShape[i];
// Compute total number of times meta element repeat
// = product(indexShape[dim+1:indexRank])
// dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8
// dim1: postDimMetaElementRepeatTimes = 2
// dim2: postDimMetaElementRepeatTimes = 1
for (int i = dim + 1; i < static_cast<int>(indexRank); i++)
postDimMetaElementRepeatTimes *= indexShape[i];
// Example using dim1:
// preDimMetaElementRepeatTimes = 1
// postDimMetaElementRepeatTimes = 2
// Using postDimMetaElementRepeatTimes, we get the meta element range:
// [0 0 1 1 2 2 3 3]
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
// [0 0 1 1 2 2 3 3]
//
// Let's use a clearer example:
// indexShape = [3, 4, 2]
// Target dim = 1
// => preDimMetaElementRepeatTimes = 3
// postDimMetaElementRepeatTimes = 2
// Using postDimMetaElementRepeatTimes, we get the meta element range:
// [0 0 1 1 2 2]
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
// [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2]
for (int i = 0; i < preDimMetaElementRepeatTimes; i++) {
for (size_t elementId = 0; elementId < indicesMetaElement.size(); for (size_t elementId = 0; elementId < indicesMetaElement.size();
elementId++) { elementId++) {
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { for (int j = 0; j < postDimMetaElementRepeatTimes; j++) {
indicesVec.push_back(indicesMetaElement[elementId]);
}
}
} else { // Create the one dim indices for last dim of index
// Create indices raw vector
// dim2: indicesVec= [0 1 0 1 0 1 0 1]
// Caution: indicesVec != [0 0 0 0 1 1 1 1]
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
for (size_t elementId = 0; elementId < indicesMetaElement.size();
elementId++) {
indicesVec.push_back(indicesMetaElement[elementId]); indicesVec.push_back(indicesMetaElement[elementId]);
} }
} }

View File

@ -132,12 +132,28 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) { Type elemTy) {
Value initTensor = Value initTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy); b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
Value c0 = Type fillValElemTy = elemTy;
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType())); if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0); return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
} }
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) {
Value initTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
Type fillValElemTy = elemTy;
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
}
Value castIntToIndex(OpBuilder &b, Location loc, Value v) { Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
assert(isa<IntegerType>(v.getType()) && "must be called with integer type"); assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v); return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);

View File

@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
} }
LogicalResult BindSymbolicShapeOp::verify() { LogicalResult BindSymbolicShapeOp::verify() {
if (getShapeSymbols().empty()) if (getShapeSymbols().size() !=
return emitOpError() << "requires non-empty shapeSymbols"; getShapeExpressions().getValue().getNumSymbols())
return emitOpError()
<< "requires equal number of shape symbol args and symbol args to "
"the attached affine map, since they are 1:1 mapped";
for (auto symbol : getShapeSymbols()) { for (auto symbol : getShapeSymbols()) {
Operation *definingOp = symbol.getDefiningOp(); Operation *definingOp = symbol.getDefiningOp();

View File

@ -9200,6 +9200,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
@ -10352,6 +10355,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %0 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n" " %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n" " return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
@ -11895,6 +11910,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %1 : !torch.int\n" " return %1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int4 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %2#1 : !torch.int\n"
" }\n"
" torch.prim.If.yield %4 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"
@ -14663,6 +14697,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"
@ -15601,6 +15639,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %2 : !torch.int\n" " return %2 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"size must be less than or equal to {}\"\n"
" %false = torch.constant.bool false\n"
" %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: \"\n"
" %str_2 = torch.constant.str \"dimension out of range of {}\"\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
" %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %15 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
" %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n"
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n"
" %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n"
" %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %13 = torch.aten._set_item.t %12, %4, %11 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" %14 = torch.aten.append.t %12, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %12 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
"}\n" "}\n"
""; "";
// clang-format on // clang-format on

View File

@ -7298,6 +7298,85 @@ public:
}; };
} // namespace } // namespace
namespace {
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
// op.
class DecomposeAtenAdaptiveMaxPool1dOp
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value input = op.getSelf();
std::optional<unsigned> maybeRank = getTensorRank(input);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned rank = *maybeRank;
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
Value outputSize = outputShapeSizesTorchInt[0];
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
return rewriter.notifyMatchFailure(
op, "the output size of adaptive_max_pool1d must be a constant int");
}
SmallVector<Value, 1> kernelSize;
if (outputSizeInt == 1) {
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(
inputShape[rank - 1] == kUnknownSize
? inputSize
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
}
kernelSize.push_back(constantOne);
}
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero});
Value dialationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
return success();
}
};
} // namespace
namespace { namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op. // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
@ -8720,6 +8799,77 @@ public:
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
using OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto self = op.getSelf();
auto target = op.getTarget();
auto posWeight = op.getPosWeight();
auto weight = op.getWeight();
auto reduction = op.getReduction();
Value loss;
auto one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto _one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
auto _target =
rewriter.create<AtenMulScalarOp>(loc, target.getType(), target, _one);
auto _target_1 = rewriter.create<AtenAddScalarOp>(loc, _target.getType(),
_target, one, one);
Value mm =
rewriter.create<AtenMulTensorOp>(loc, self.getType(), _target_1, self);
Value logSigm =
rewriter.create<AtenLogSigmoidOp>(loc, self.getType(), self);
if (!isa<Torch::NoneType>(posWeight.getType())) {
auto logWeight = rewriter.create<AtenAddScalarOp>(
loc, posWeight.getType(),
rewriter.create<AtenSubScalarOp>(loc, posWeight.getType(), posWeight,
one, one),
one, one);
loss = rewriter.create<AtenSubTensorOp>(
loc, mm.getType(), mm,
rewriter.create<AtenMulTensorOp>(loc, logWeight.getType(), logWeight,
logSigm),
one);
} else {
loss =
rewriter.create<AtenSubTensorOp>(loc, mm.getType(), mm, logSigm, one);
}
if (!isa<Torch::NoneType>(weight.getType())) {
loss =
rewriter.create<AtenMulTensorOp>(loc, loss.getType(), loss, weight);
}
// apply loss reduction.
int64_t reductionInt;
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
return rewriter.notifyMatchFailure(op, "no reduction type is appointed!");
}
auto none = rewriter.create<ConstantNoneOp>(loc);
Value res;
if (reductionInt == 1) {
res = rewriter.create<AtenMeanOp>(loc, op.getType(), loss, none);
} else if (reductionInt == 2) {
res = rewriter.create<AtenSumOp>(loc, op.getType(), loss, none);
} else {
res = loss;
}
rewriter.replaceOp(op, res);
return success();
}
};
} // namespace
namespace { namespace {
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> { class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern; using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
@ -9801,6 +9951,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
@ -9856,6 +10007,8 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);

View File

@ -530,11 +530,139 @@ public:
none, none, none, none); none, none, none, none);
return success(); return success();
} }
auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
if (squeezeOp && resultTy.getSizes().size() == 1) {
rewriter.replaceOp(op, squeezeOp.getSelf());
return success();
}
return failure(); return failure();
} }
}; };
} // namespace } // namespace
namespace {
// This is a specific pattern for converting views like [?,...,?,lastDim] ->
// [?,...,?,factor0,factor1] to unflatten, and views like
// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
// possible to infer that all but last shared dim match
// TODO: move this to an actual canonicalizer for view after deleting the
// conflicting decompositions for flatten/unflatten -> view.
class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
public:
using OpRewritePattern<AtenViewOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenViewOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> viewSizes;
if (failed(getListOperands(op.getSize(), viewSizes)))
return rewriter.notifyMatchFailure(
op, "view size must be from a list construct");
auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
if (!selfTy || !selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "missing input type or sizes");
auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasSizes() ||
resultTy.getSizes().size() != viewSizes.size())
return rewriter.notifyMatchFailure(op, "missing result type or sizes");
int64_t inRank = selfTy.getSizes().size();
int64_t outRank = resultTy.getSizes().size();
SmallVector<int64_t> sizes(selfTy.getSizes());
int64_t endMatchingDim = -1;
// input sizes vs. provided view sizes comparison loop
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
int64_t providedSize;
bool providedStatic =
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
// if sizes[i] is static, it must match a constant in viewSizes[i]
if (sizes[i] != Torch::kUnknownSize) {
if (!providedStatic)
return rewriter.notifyMatchFailure(
op, "unsupported: found static input dim, but unable to match "
"provided view size on a constant. See position : " +
std::to_string(i));
if (providedSize != sizes[i]) {
endMatchingDim = i;
break;
}
continue;
}
// the remaining assumes sizes[i] is dynamic
// if provided dim is static, we can't verify it is a flatten/unflatten
// unless -1
if (i == outRank - 1 && providedStatic && providedSize == -1) {
endMatchingDim = i;
break;
}
if (providedStatic)
return rewriter.notifyMatchFailure(
op, "unexpected static view dim corresponding to dynamic input dim "
"at position : " +
std::to_string(i));
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, fail
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
return rewriter.notifyMatchFailure(
op, "expected dynamic view dim to come from a corresponding "
"size.int op. See position : " +
std::to_string(i));
int64_t dim;
// if the dim of the size int op doesn't match, fail
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
dim != i)
return rewriter.notifyMatchFailure(
op,
"size int op dim cannot be matched to current dim at position : " +
std::to_string(i));
// passing the previous checks means viewSizes[i] = aten.size.int(self,
// i), so continue
}
// if all dims match and the ranks are equal, fold
if (endMatchingDim == -1 && inRank == outRank) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
if (endMatchingDim > -1 && inRank > outRank) {
// only support flattening last dim
if (endMatchingDim != outRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: output has more than back dim mismatching");
// flatten
Value start =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value end =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
op, resultTy, op.getSelf(), start, end);
return success();
}
if (endMatchingDim > -1 && inRank < outRank) {
// only support unflattening last dim
if (endMatchingDim != inRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: input has more than back dim mismatching");
// unflatten
Value dim =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value primList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(),
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
op, resultTy, op.getSelf(), dim, primList);
return success();
}
// examples that might reach this:
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
return rewriter.notifyMatchFailure(
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
", inRank=" + std::to_string(inRank) +
", outRank=" + std::to_string(outRank));
}
};
} // namespace
namespace { namespace {
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> { template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
public: public:
@ -561,18 +689,24 @@ public:
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
patterns patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern, PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern, FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern, FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
FoldAtenWhereSelf, RemoveUnusedPattern<Torch::AtenSizeIntOp>, RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>, RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>, RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::ConstantBoolOp>, RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::ConstantIntOp>, RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
RemoveUnusedPattern<Torch::ConstantNoneOp>, RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(context); RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::ConstantBoolOp>,
RemoveUnusedPattern<Torch::ConstantIntOp>,
RemoveUnusedPattern<Torch::ConstantNoneOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
context->getLoadedDialect<mlir::arith::ArithDialect>() context->getLoadedDialect<mlir::arith::ArithDialect>()
->getCanonicalizationPatterns(patterns); ->getCanonicalizationPatterns(patterns);

View File

@ -90,7 +90,28 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::Float8_e5m2fnuz; return torch_upstream::ScalarType::Float8_e5m2fnuz;
if (isa<Float8E4M3FNUZType>(type)) if (isa<Float8E4M3FNUZType>(type))
return torch_upstream::ScalarType::Float8_e4m3fnuz; return torch_upstream::ScalarType::Float8_e4m3fnuz;
llvm::report_fatal_error("unhandled type for getScalarTypeForType"); std::string errorMsg = "Unhandled type in getScalarTypeForType: ";
llvm::raw_string_ostream os(errorMsg);
type.print(os);
// os << "\nType ID: " << type.getTypeID();
os << "\nType properties:";
os << "\n Is integer: " << (type.isInteger() ? "yes" : "no");
os << "\n Is float: "
<< (type.isIntOrFloat() && !type.isInteger() ? "yes" : "no");
os << "\n Is index: " << (type.isIndex() ? "yes" : "no");
os << "\n Bit width: "
<< (type.isIntOrFloat() ? std::to_string(type.getIntOrFloatBitWidth())
: "N/A");
os << "\n Is signless: " << (type.isSignlessInteger() ? "yes" : "no");
os << "\n Is signed: " << (type.isSignedInteger() ? "yes" : "no");
// special error message for unsigned integer
if (type.isUnsignedInteger()) {
os << "\n Is unsigned: yes";
os << "\nUnsigned integer support is currently spotty. Please seeheck "
"https://github.com/llvm/torch-mlir/issues/3720 "
"for more details.";
}
llvm::report_fatal_error(llvm::StringRef(errorMsg));
} }
Type Torch::getTypeForTorchType( Type Torch::getTypeForTorchType(
MLIRContext *context, Type type, MLIRContext *context, Type type,
@ -257,7 +278,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp, PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
AtenPixelShuffleOp, AtenDiagonalOp>(op); AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
} }
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,

View File

@ -79,6 +79,7 @@ TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors #### General TorchDynamo/PyTorch errors
# torch._dynamo.exc.Unsupported: Tensor.item # torch._dynamo.exc.Unsupported: Tensor.item
"CumsumModule_basic", "CumsumModule_basic",
"CumprodModule_basic",
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
# RuntimeError: Failed running call_function aten.convolution_backward(... # RuntimeError: Failed running call_function aten.convolution_backward(...
# https://github.com/pytorch/pytorch/issues/89629 # https://github.com/pytorch/pytorch/issues/89629
@ -432,6 +433,7 @@ FX_IMPORTER_XFAIL_SET = {
"ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
"CumsumModule_basic", "CumsumModule_basic",
"CumprodModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DivFloatModule_basic", "DivFloatModule_basic",
"DivIntModule_basic", "DivIntModule_basic",
@ -504,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = {
"UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dDynamicFactor_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic", "ViewSizeFromOtherTensor_basic",
"ViewDtypeStaticModule_basic",
"WeightNormInterfaceModule_basic", "WeightNormInterfaceModule_basic",
# Error: `aten.as_strided` op is not supported # Error: `aten.as_strided` op is not supported
"ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackDynamic_Module_basic",
@ -588,6 +591,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"AdaptiveAvgPool3dDynamic_basic", "AdaptiveAvgPool3dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic",
"AdaptiveMaxPool1dDynamic_basic", "AdaptiveMaxPool1dDynamic_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"AdaptiveMaxPool1dStatic_basic", "AdaptiveMaxPool1dStatic_basic",
"AdaptiveMaxPool2dDynamicNoBatch_basic", "AdaptiveMaxPool2dDynamicNoBatch_basic",
"AdaptiveMaxPool2dDynamicWithIndices_basic", "AdaptiveMaxPool2dDynamicWithIndices_basic",
@ -666,6 +670,10 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
"CumsumModule_basic", "CumsumModule_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DeterminantBatchedModule_F32", "DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32", "DeterminantDynamicModule_F32",
@ -808,10 +816,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"RandnLikeDtypeModule_basic", "RandnLikeDtypeModule_basic",
"RandnLikeModule_basic", "RandnLikeModule_basic",
"RandnModule_basic", "RandnModule_basic",
"ReduceAllDimBool_basic",
"ReduceAllDimEmpty_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceProdDimIntFloatModule_basic", "ReduceProdDimIntFloatModule_basic",
"ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic", "ReflectionPad1dModule2dInput_basic",
@ -829,18 +833,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ReplicationPad2dModule_top0", "ReplicationPad2dModule_top0",
"RsubInt0d_NumToTensor_Module_basic", "RsubInt0d_NumToTensor_Module_basic",
"ScalarImplicitFloatModule_basic", "ScalarImplicitFloatModule_basic",
# need aten.all.dim lowering to stablehlo
"SafeSoftmaxModule_basic",
"SafeSoftmaxNonNoneDtypeModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED # REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScatterReduceFloatMaxModule", "ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule", "ScatterReduceFloatMeanModule",
@ -926,6 +919,11 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"GCDBatchedModule_I32", "GCDBatchedModule_I32",
"GCDDynamicModule_I32", "GCDDynamicModule_I32",
"GCDModule_I32", "GCDModule_I32",
"Unfold_Module_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
} }
FX_IMPORTER_STABLEHLO_CRASHING_SET = { FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@ -1059,6 +1057,7 @@ STABLEHLO_PASS_SET = {
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"ContiguousModule_basic", "ContiguousModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
@ -1079,6 +1078,9 @@ STABLEHLO_PASS_SET = {
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DetachModule_basic", "DetachModule_basic",
"DivFloatModule_basic", "DivFloatModule_basic",
"DivIntModule_basic", "DivIntModule_basic",
@ -1425,6 +1427,7 @@ STABLEHLO_PASS_SET = {
"SliceSizeTwoStepModule_basic", "SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic", "SliceStartEqEndModule_basic",
"SliceStaticModule_basic", "SliceStaticModule_basic",
"SliceStaticComplexInputModule_basic",
"SliceWholeTensorModule_basic", "SliceWholeTensorModule_basic",
"SortIntListReverse_basic", "SortIntListReverse_basic",
"SortIntList_basic", "SortIntList_basic",
@ -1671,6 +1674,35 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"AtenRoundFloatHalfToEvenModule_basic",
"AtenRoundFloatModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithInt64Static_basic",
"FlipModuleStaticShape_basic",
"FlipModule_basic",
"FlipNegativeIndexModule_basic",
"Rot90BasicModule_basic",
"Rot90DynamicDimsModule_basic",
"Rot90MultipleRotationsModule_basic",
"Rot90NegativeEvenRotationsModule_basic",
"Rot90NegativeOddRotationsModule_basic",
"AtenLinalgCrossBroadcast_basic",
"AtenLinalgCrossCustomDim_basic",
"AtenLinalgCrossFloat_basic",
"AtenLinalgCrossInt_basic",
"AtenLinalgCrossNegativeDim_basic",
"BinaryCrossEntropyWithLogitsStaticModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"DiagonalWithStaticShapeModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic", "ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideScalarModule_basic", "ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic", "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
@ -1814,7 +1846,6 @@ TOSA_PASS_SET = {
"ArangeStartOutModule_basic", "ArangeStartOutModule_basic",
"ArangeStartOutViewModule_basic", "ArangeStartOutViewModule_basic",
"ArangeStartStepIntModule_basic", "ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"ArangeDtypeIntModule_basic", "ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic", "ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic", "ArangeFloatModule_basic",
@ -2115,7 +2146,6 @@ TOSA_PASS_SET = {
"NormScalarOptDimModule_basic", "NormScalarOptDimModule_basic",
"NumToTensorFloatModule_basic", "NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic", "NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic", "NumpyTRank1Module_basic",
"NumpyTRank2Module_basic", "NumpyTRank2Module_basic",
"NumpyTRankNDynamicModule_basic", "NumpyTRankNDynamicModule_basic",
@ -2127,7 +2157,6 @@ TOSA_PASS_SET = {
"OnesModuleInt_basic", "OnesModuleInt_basic",
"PadModule_basic", "PadModule_basic",
"PadWithNoneValModule_basic", "PadWithNoneValModule_basic",
"Permute0RankModule_basic",
"PermuteModule_basic", "PermuteModule_basic",
"PermuteNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic",
"PrimListUnpackNumMismatchModule_basic", "PrimListUnpackNumMismatchModule_basic",
@ -2166,7 +2195,6 @@ TOSA_PASS_SET = {
"ScalarTensorInt64Module_basic", "ScalarTensorInt64Module_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic", "SelectIntNegativeDimAndIndexStaticModule_basic",
"SiluModule_basic", "SiluModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStaticModule_basic", "SliceStaticModule_basic",
"SplitTensorGetItem_Module_basic", "SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic", "SplitTensorLastSmallerModule_basic",
@ -2348,6 +2376,13 @@ MAKE_FX_TOSA_PASS_SET = (
} }
) - { ) - {
### Test failing in make_fx_tosa but not in tosa ### Test failing in make_fx_tosa but not in tosa
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
# Dynamic shape, has extra unsupported broadcast ops # Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d", "Matmul_3d",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
@ -2588,6 +2623,7 @@ ONNX_XFAIL_SET = {
"SliceCopyNegative_Module_basic", "SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic", "SliceCopyNonZeroDim_Module_basic",
"SliceCopy_Module_basic", "SliceCopy_Module_basic",
"SliceStaticComplexInputModule_basic",
"StdCorrectionLargeInputModule_basic", "StdCorrectionLargeInputModule_basic",
"TupleModule_basic", "TupleModule_basic",
"VarCorrectionLargeInputModule_basic", "VarCorrectionLargeInputModule_basic",
@ -2757,6 +2793,7 @@ ONNX_XFAIL_SET = {
"ElementwiseExpm1IntModule_basic", "ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic", "ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic", "ElementwiseFmodTensor_Int_basic",
"ElementwiseCreateComplexModule_basic",
"ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic", "ElementwiseOrTensorModule_basic",
@ -3071,7 +3108,6 @@ ONNX_XFAIL_SET = {
"ScatterReduceIntMaxModuleIncludeSelf", "ScatterReduceIntMaxModuleIncludeSelf",
"ScatterReduceIntMinModuleIncludeSelf", "ScatterReduceIntMinModuleIncludeSelf",
"ScatterValueFloatModule_basic", "ScatterValueFloatModule_basic",
"ScatterAddStaticModule_basic",
# Failure - onnx_lowering: onnx.ScatterND # Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic",
@ -3107,6 +3143,10 @@ ONNX_XFAIL_SET = {
"CopyWithDifferentDTypesModule_basic", "CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"ElementwiseAcosIntModule_basic", "ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic", "ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanTensorIntModule_basic",
@ -3132,6 +3172,11 @@ ONNX_XFAIL_SET = {
"GCDBatchedModule_I32", "GCDBatchedModule_I32",
"GCDDynamicModule_I32", "GCDDynamicModule_I32",
"GCDModule_I32", "GCDModule_I32",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
"ViewDtypeStaticModule_basic",
} }
if torch_version_for_comparison() < version.parse("2.3.0.dev"): if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@ -3170,6 +3215,18 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
"AtenIntMM_basic", "AtenIntMM_basic",
} }
if torch_version_for_comparison() > version.parse("2.4.0.dev"):
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
"ElementwiseCreateComplexModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
}
FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
"ElementwiseCreateComplexModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
}
ONNX_CRASHING_SET = LINALG_CRASHING_SET | { ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
"FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineModule_basic",
@ -3197,17 +3254,30 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
} }
FX_IMPORTER_TOSA_XFAIL_SET = { FX_IMPORTER_TOSA_XFAIL_SET = {
"ArangeZeroElementOutputModule_basic",
"NumpyTRank0Module_basic",
"Permute0RankModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStartEqEndModule_basic",
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"ElementwiseCreateComplexModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"AtenPolarDoubleModule_basic", "AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic", "AtenPolarFloatModule_basic",
"HstackBasicComplexModule_basic", "HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic", "HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic", "HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic", "HstackBasicIntModule_basic",
"Rot90BasicModule_basic",
"Rot90DynamicDimsModule_basic",
"Rot90MultipleRotationsModule_basic",
"Rot90NegativeEvenRotationsModule_basic",
"Rot90NegativeOddRotationsModule_basic",
"AtenIntMM_basic", "AtenIntMM_basic",
"AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic",
@ -3220,14 +3290,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"Conv_Transpose2dStaticModule_basic", "Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic", "Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic", "Conv_Transpose3dStaticModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic", "ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseRreluEvalModule_basic", "ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic", "ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainModule_basic", "ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluTrainStaticModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic", "IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic", "MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic", "MaxUnpool3dModulePad0_basic",
@ -3294,12 +3362,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenIntTensorCharDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic",
"AtenItemFpOpModule_basic", "AtenItemFpOpModule_basic",
"AtenItemIntOpModule_basic", "AtenItemIntOpModule_basic",
"AtenLinalgCrossBroadcast_basic",
"AtenLinalgCrossCustomDim_basic",
"AtenLinalgCrossDynamic_basic",
"AtenLinalgCrossFloat_basic",
"AtenLinalgCrossInt_basic",
"AtenLinalgCrossNegativeDim_basic",
"AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8Transpose_basic",
"AtenMatmulQMixedSigni8_basic", "AtenMatmulQMixedSigni8_basic",
"AtenMatmulQint8MV_basic", "AtenMatmulQint8MV_basic",
@ -3312,8 +3374,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenMmQuint8_basic", "AtenMmQuint8_basic",
"AtenRealView128Module_basic", "AtenRealView128Module_basic",
"AtenRealView64Module_basic", "AtenRealView64Module_basic",
"AtenRoundFloatHalfToEvenModule_basic",
"AtenRoundFloatModule_basic",
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
"AtenTopKModule_basic", "AtenTopKModule_basic",
"AtenTopKSmallestModule_basic", "AtenTopKSmallestModule_basic",
@ -3355,6 +3415,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"Conv1dModule_basic", "Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dQInt8Module_basic", "Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped", "Conv2dQInt8Module_grouped",
@ -3383,18 +3444,14 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"CumsumModule_basic", "CumsumModule_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DeterminantBatchedModule_F32", "DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32", "DeterminantDynamicModule_F32",
"DeterminantModule_F32", "DeterminantModule_F32",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
"DiagonalWithStaticShapeModule_basic",
"DivFloatModule_basic", "DivFloatModule_basic",
"DivIntModule_basic", "DivIntModule_basic",
"DropoutTrainModule_basic", "DropoutTrainModule_basic",
@ -3478,20 +3535,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"EqIntModule_basic", "EqIntModule_basic",
"ExpandModule_basic", "ExpandModule_basic",
"ExponentialModule_basic", "ExponentialModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
"Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64Static_basic",
"Fill_TensorFloat64WithInt64_basic",
"FlipModuleStaticShape_basic",
"FlipModule_basic",
"FlipNegativeIndexModule_basic",
"FloatImplicitModule_basic", "FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic", "FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic", "FullLikeModuleInt3D_basic",
@ -3547,15 +3590,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic", "IndexPutImplIndexWithNoneModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectRank0IdxModule_basic", "IndexSelectRank0IdxModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorNegativeIndexModule_basic", "IndexTensorNegativeIndexModule_basic",
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_sizes_nearest",
@ -3753,6 +3788,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"SignAndLogarithmOfDeterminantModule_F32", "SignAndLogarithmOfDeterminantModule_F32",
"SignAndLogarithmOfDeterminantBatchedModule_F32", "SignAndLogarithmOfDeterminantBatchedModule_F32",
"SignAndLogarithmOfDeterminantDynamicModule_F32", "SignAndLogarithmOfDeterminantDynamicModule_F32",
"SliceStaticComplexInputModule_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic", "SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic", "SliceCopyNonZeroDim_Module_basic",
@ -3808,11 +3844,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ToCopyWithDTypeModule_basic", "ToCopyWithDTypeModule_basic",
"TorchPrimLoopForLikeModule_basic", "TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic", "TorchPrimLoopWhileLikeModule_basic",
"TraceModule_basic",
"TraceModule_empty", "TraceModule_empty",
"TraceModule_nonsquare",
"TraceSignedIntModule_basic",
"TraceUnsignedIntModule_basic",
"TraceUnsignedIntModule_empty", "TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic", "TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic", "TypeConversionI1ToI32Module_basic",
@ -3833,9 +3865,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"VarMeanUnbiasedModule_basic", "VarMeanUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic", "ViewSizeFromOtherTensor_basic",
"ZeroFloat32Module_basic", "VisionTransformerModule_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_falsePinMemory", "ZerosLikeModule_falsePinMemory",
} }
@ -3848,6 +3878,15 @@ ONNX_TOSA_CRASHING_SET = {
} }
ONNX_TOSA_XFAIL_SET = { ONNX_TOSA_XFAIL_SET = {
"ArangeZeroElementOutputModule_basic",
"LinspaceEmptyModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"TrilIndicesAllZerosModule_basic",
"TriuIndicesAllZerosModule_basic",
"ElementwiseCreateComplexModule_basic",
"ReduceAllDimFloatModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic", "ScaledDotProductAttentionDifferentCausalModule_basic",
"HstackBasicComplexModule_basic", "HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic", "HstackBasicFloatModule_basic",
@ -3877,7 +3916,6 @@ ONNX_TOSA_XFAIL_SET = {
"Conv_Transpose2dStaticModule_basic", "Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic", "Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic", "Conv_Transpose3dStaticModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"EinsumStaticModule_basic", "EinsumStaticModule_basic",
"ElementwiseFmaxModule_basic", "ElementwiseFmaxModule_basic",
"ElementwiseFminModule_basic", "ElementwiseFminModule_basic",
@ -4010,8 +4048,6 @@ ONNX_TOSA_XFAIL_SET = {
"AtenPolarDoubleModule_basic", "AtenPolarDoubleModule_basic",
"AtenRealView128Module_basic", "AtenRealView128Module_basic",
"AtenRealView64Module_basic", "AtenRealView64Module_basic",
"AtenRoundFloatHalfToEvenModule_basic",
"AtenRoundFloatModule_basic",
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
"AtenTopKModule_basic", "AtenTopKModule_basic",
"AtenTopKSmallestModule_basic", "AtenTopKSmallestModule_basic",
@ -4055,8 +4091,6 @@ ONNX_TOSA_XFAIL_SET = {
"BucketizeTensorFloatModule_basic", "BucketizeTensorFloatModule_basic",
"BucketizeTensorModule_basic", "BucketizeTensorModule_basic",
"BucketizeTensorOutInt32RightModule_basic", "BucketizeTensorOutInt32RightModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CeilFloatModule_basic", "CeilFloatModule_basic",
"ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic",
@ -4075,6 +4109,7 @@ ONNX_TOSA_XFAIL_SET = {
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"Conv1dModule_basic", "Conv1dModule_basic",
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
"Conv2dBiasNoPaddingModule_basic", "Conv2dBiasNoPaddingModule_basic",
"Conv2dModule_basic", "Conv2dModule_basic",
"Conv2dNoPaddingModule_basic", "Conv2dNoPaddingModule_basic",
@ -4115,6 +4150,10 @@ ONNX_TOSA_XFAIL_SET = {
"CumsumModule_basic", "CumsumModule_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic", "CumsumStaticNegativeDimModule_basic",
"CumprodModule_basic",
"CumprodInputDtypeInt32Module_basic",
"CumprodStaticModule_basic",
"CumprodStaticNegativeDimModule_basic",
"DeformConv2D_basic", "DeformConv2D_basic",
"DeterminantModule_F32", "DeterminantModule_F32",
"DeterminantBatchedModule_F32", "DeterminantBatchedModule_F32",
@ -4265,7 +4304,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseWhereSelfModule_basic", "ElementwiseWhereSelfModule_basic",
"EmbeddingModule1DIndices_basic", "EmbeddingModule1DIndices_basic",
"EmbeddingModuleF16_basic", "EmbeddingModuleF16_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic", "EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic", "EmbeddingModuleI64_basic",
"EmptyLikeMemoryFormatModule_basic", "EmptyLikeMemoryFormatModule_basic",
@ -4359,12 +4397,6 @@ ONNX_TOSA_XFAIL_SET = {
"IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic", "IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic", "IndexSelectDynamicModulebasic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectRank0IdxModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinModule3dInput_basic", "IndexTensorHackedTwinModule3dInput_basic",
@ -4382,10 +4414,8 @@ ONNX_TOSA_XFAIL_SET = {
"IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic", "IndexTensorMultiInput_basic",
"IndexTensorNegativeIndexModule_basic",
"IndexTensorSelectDimModule_basic", "IndexTensorSelectDimModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic", "IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_sizes_nearest",
@ -4684,7 +4714,6 @@ ONNX_TOSA_XFAIL_SET = {
"ScatterValueFloatModule_basic", "ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic", "ScatterValueIntModule_basic",
"SelectIntModule_basic", "SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertModule_basic", "SelectScattertModule_basic",
"SelectScattertStaticModule_basic", "SelectScattertStaticModule_basic",
"SignAndLogarithmOfDeterminantModule_F32", "SignAndLogarithmOfDeterminantModule_F32",
@ -4696,6 +4725,7 @@ ONNX_TOSA_XFAIL_SET = {
"SliceCopy_Module_basic", "SliceCopy_Module_basic",
"SliceEndSleStartModule_basic", "SliceEndSleStartModule_basic",
"SliceModule_basic", "SliceModule_basic",
"SliceStaticComplexInputModule_basic",
"SliceNegIdxModule_basic", "SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic", "SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic",

View File

@ -1445,6 +1445,9 @@ def atenmultinomial〡shape(self: List[int], num_samples: int, replacement: b
def atencumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: def atencumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return self return self
def atencumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return self
def atenrand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: def atenrand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self return self
@ -2001,6 +2004,14 @@ def atenmse_loss〡shape(self: List[int], target: List[int], reduction: int =
def atencross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: def atencross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing) return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
def atenbinary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]:
scalar_shape: List[int] = []
if reduction == 0:
result_shape = upstream_shape_functions._copy(self)
else:
result_shape = scalar_shape
return result_shape
@check_shape_function([ @check_shape_function([
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
]) ])
@ -2937,6 +2948,18 @@ def atencumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
return torch.int64 return torch.int64
return self_dtype return self_dtype
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32))
def atencumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
if dtype is not None:
return dtype
self_rank, self_dtype = self_rank_dtype
if is_integer_dtype(self_dtype):
return torch.int64
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atendetach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atendetach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
@ -4954,6 +4977,10 @@ def atenlinalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
return dtype return dtype
return atenstd〡dtype(self_rank_dtype) return atenstd〡dtype(self_rank_dtype)
def atenbinary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function( @check_dtype_function(
_check_tensors_with_the_same_dtype( _check_tensors_with_the_same_dtype(
tensor_shapes=[(3,3)], tensor_shapes=[(3,3)],
@ -5543,7 +5570,45 @@ def aten_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int,
return torch.qint8 return torch.qint8
return torch.qint32 return torch.qint32
@check_shape_function([
Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero.
Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0.
Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case.
Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case.
Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension.
Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension.
])
def atenunfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]:
ndim = len(self)
# Rank zero tensor
if ndim == 0:
assert dimension == 0, f"dimension out of range of {ndim}"
assert size <= 1, "size must be less than or equal to 1"
return [size]
dim = dimension
if dim < 0:
dim += ndim
assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}"
size_dim = self[dim]
assert size <= size_dim, f"size must be less than or equal to {size_dim}"
num_blocks = (size_dim - size) // step + 1
out = upstream_shape_functions._copy(self)
out[dim] = num_blocks
out.append(size)
return out
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1)
)
def atenunfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

View File

@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)") emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
emit("aten::rad2deg : (Tensor) -> (Tensor)") emit("aten::rad2deg : (Tensor) -> (Tensor)")
emit("aten::complex : (Tensor, Tensor) -> (Tensor)")
emit("aten::real : (Tensor) -> (Tensor)") emit("aten::real : (Tensor) -> (Tensor)")
emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)")
@ -617,6 +618,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
) )
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit(
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
)
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
emit( emit(
@ -740,6 +744,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)" "aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
) )
emit(
"aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)"
)
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
@ -986,6 +993,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)") emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)")
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)") emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)") emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")

View File

@ -42,7 +42,7 @@ def import_onnx(contents):
# Import the ONNX model proto from the file contents: # Import the ONNX model proto from the file contents:
raw_model = onnx.load_from_string(contents) raw_model = onnx.load_from_string(contents)
# since it does not affect current e2e tests, data_prop is left false here # since it does not affect current e2e tests, data_prop is left false here
model_proto = onnx.shape_inference.infer_shapes(raw_model) model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
# Import the ONNX module into an MLIR module: # Import the ONNX module into an MLIR module:
context = Context() context = Context()

View File

@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class CumprodModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, val):
ones = torch.ones([1], dtype=torch.int32)
return torch.ops.aten.cumprod(val, ones.item())
@register_test_case(module_factory=lambda: CumprodModule())
def CumprodModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
class CumprodStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 7, 4], torch.float32, True),
]
)
def forward(self, val):
return torch.ops.aten.cumprod(val, 1)
@register_test_case(module_factory=lambda: CumprodStaticModule())
def CumprodStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
class CumprodStaticNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 7, 4], torch.float32, True),
]
)
def forward(self, val):
return torch.ops.aten.cumprod(val, dim=-1)
@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule())
def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
class CumprodInputDtypeInt32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 7, 4], torch.int32, True),
]
)
def forward(self, val):
return torch.ops.aten.cumprod(val, 1)
@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module())
def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 7, 4).to(torch.int32))
# ==============================================================================
class AtenToDeviceModule(torch.nn.Module): class AtenToDeviceModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
module.forward(inputVec, weight, bias) module.forward(inputVec, weight, bias)
class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 4, 6], torch.float32, True),
([4, 1, 3], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.conv1d(
inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
)
@register_test_case(
module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
)
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
inputVec = tu.rand(2, 4, 6)
weight = torch.randn(4, 1, 3)
module.forward(inputVec, weight)
class Conv2dModule(torch.nn.Module): class Conv2dModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseCreateComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1], torch.float32, True),
([-1], torch.float32, True),
]
)
def forward(self, a, b):
return torch.complex(a, b)
@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule())
def ElementwiseCreateComplexModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.float32),
tu.randint(4, high=10).type(torch.float32),
)
# ==============================================================================
class ElementwiseMulTensorComplexModule(torch.nn.Module): class ElementwiseMulTensorComplexModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1783,6 +1783,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10)) module.forward(tu.rand(1, 512, 10))
class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
def __init__(self):
super().__init__()
self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False)
@export
@annotate_args([None, ([1, 512, 7], torch.float32, True)])
def forward(self, x):
return self.amp1d(x)
@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic())
def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7))
# AdaptiveMaxPool2d # AdaptiveMaxPool2d

View File

@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5)) module.forward(tu.rand(3, 4, 5))
class ReduceAllDimFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.all(a, dim=0)
@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ============================================================================== # ==============================================================================
@ -2294,6 +2314,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils):
module.forward(tu.rand(8, 2), tu.randint(8, high=2)) module.forward(tu.rand(8, 2), tu.randint(8, high=2))
class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([8, 2], torch.float32, True),
([8, 2], torch.float32, True),
]
)
def forward(self, input, target):
return torch.ops.aten.binary_cross_entropy_with_logits(
input, target, reduction=0
)
@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule())
def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(8, 2), tu.rand(8, 2))
# ============================================================================== # ==============================================================================

View File

@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ViewDtypeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([12, 1], torch.float32, True),
]
)
def forward(self, a):
res = a.view(torch.int8)
return res
@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
def ViewDtypeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(12, 1))
# ==============================================================================
class ReshapeAliasCollapseModule(torch.nn.Module): class ReshapeAliasCollapseModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -1648,3 +1672,103 @@ class Rot90NegativeEvenRotationsModule(torch.nn.Module):
@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule()) @register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils): def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 5, 1, 7, 3)) module.forward(tu.rand(6, 5, 1, 7, 3))
class Unfold_Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([6, 4], torch.float32, True),
]
)
def forward(self, x):
return x.unfold(0, 2, 2)
@register_test_case(module_factory=lambda: Unfold_Module())
def Unfold_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4))
class Unfold_Module_Negative_Dim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([6, 4, 4, 4], torch.float32, True),
]
)
def forward(self, x):
return x.unfold(-1, 2, 1)
@register_test_case(module_factory=lambda: Unfold_Module_Negative_Dim())
def Unfold_Module_Rank_4(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 4, 4))
class Unfold_Module_Rank_Zero(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([], torch.float32, True),
]
)
def forward(self, x):
return x.unfold(0, 1, 1)
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
def Unfold_Module_Rank_Zero_basic(module, tu: TestUtils):
module.forward(tu.rand())
class Unfold_Module_Rank_Zero_Size_Zero(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([], torch.float32, True),
]
)
def forward(self, x):
return x.unfold(0, 0, 1)
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
module.forward(tu.rand())
class Unfold_Module_Dynamic(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return x.unfold(1, 2, 1)
@register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 4, 4))

View File

@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class SliceStaticComplexInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([6, 4, 7], torch.complex64, True),
]
)
def forward(self, x):
return x[0:5:1, 1:3:1, 2:4:1]
@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7).to(torch.complex64))
# ==============================================================================
class SliceOutOfUpperBoundIndexModule(torch.nn.Module): class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -223,13 +223,13 @@ INSTALL_REQUIRES = [
EXT_MODULES = [ EXT_MODULES = [
CMakeExtension("torch_mlir._mlir_libs._torchMlir"), CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
] ]
NAME = "torch-mlir-core" NAME = "torch-mlir"
# If building PyTorch extensions, customize. # If building PyTorch extensions, customize.
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
import torch import torch
NAME = "torch-mlir" NAME = "torch-mlir-ext"
INSTALL_REQUIRES.extend( INSTALL_REQUIRES.extend(
[ [
f"torch=={torch.__version__}".split("+", 1)[0], f"torch=={torch.__version__}".split("+", 1)[0],

View File

@ -16,10 +16,71 @@
// CHECK-DAG: torch.prim.Loop.condition // CHECK-DAG: torch.prim.Loop.condition
// CHECK-DAG: } // CHECK-DAG: }
// CHECK: } // CHECK: }
module {
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none %none = torch.constant.none
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32> return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
} }
// -----
// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias(
// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>,
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>,
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>,
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>)
// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) {
// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>):
// CHECK-DAG: torch.aten.select.int
// CHECK-DAG: torch.aten.linear
// CHECK-DAG: torch.aten.sigmoid
// CHECK-DAG: torch.aten.tanh
// CHECK-DAG: torch.prim.Loop.condition
// CHECK: }
// CHECK: torch.aten.flip
// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) {
// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>):
// CHECK-DAG: torch.aten.select.int
// CHECK-DAG: torch.aten.linear
// CHECK-DAG: torch.aten.sigmoid
// CHECK-DAG: torch.aten.tanh
// CHECK-DAG: torch.prim.Loop.condition
// CHECK: }
// CHECK: torch.aten.flip
// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
// CHECK: }
func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
}
// -----
// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs(
// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>,
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>,
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>)
// CHECK: torch.aten.transpose.int
// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>):
// CHECK-DAG: torch.aten.select.int
// CHECK-DAG: torch.aten.linear
// CHECK-DAG: torch.aten.sigmoid
// CHECK-DAG: torch.aten.tanh
// CHECK-DAG: torch.prim.Loop.condition
// CHECK: }
// CHECK-DAG: torch.aten.transpose.int
// CHECK-DAG: torch.aten.transpose.int
// CHECK-DAG: torch.aten.transpose.int
// CHECK-DAG: torch.aten.transpose.int
// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
// CHECK: }
func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>)
return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
} }

View File

@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int // CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32> // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32> return %0 : !torch.vtensor<[3,4],f32>
@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1 // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]] // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0 // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]] // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int // CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
// CHECK-NEXT: torch.runtime.assert %[[GE]]
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2 // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]] // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]] // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1 // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]] // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
// CHECK: return %[[EXPAND]] // CHECK: return %[[EXPAND]]
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>

View File

@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[AXIS:.*]] = torch.constant.int 1 // CHECK: %[[AXIS:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.+]] = torch.constant.int 0 // CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[ONE:.+]] = torch.constant.int 1 // CHECK: %[[FIVE:.*]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
// CHECK: %[[STR:.*]] = torch.constant.str "add" // CHECK: %[[STR:.*]] = torch.constant.str "sum"
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> // CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
return %0 : !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32>
} }
@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[AXIS:.*]] = torch.constant.int 1 // CHECK: %[[AXIS:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.+]] = torch.constant.int 0 // CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[ONE:.+]] = torch.constant.int 1 // CHECK: %[[FIVE:.*]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
// CHECK: %[[STR:.*]] = torch.constant.str "multiply" // CHECK: %[[STR:.*]] = torch.constant.str "prod"
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> // CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
return %0 : !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32>
} }
@ -2833,6 +2835,15 @@ func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>)
return %0 : !torch.vtensor<[1],si64> return %0 : !torch.vtensor<[1],si64>
} }
// -----
// CHECK-LABEL: func.func @test_shape_scalar
func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
// CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64>
// CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64>
return %0: !torch.vtensor<[?],si64>
}
// ----- // -----

View File

@ -0,0 +1,17 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
// CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
%int0 = torch.constant.int 0
%1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}

View File

@ -696,8 +696,8 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> // CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
// CHECK: } // CHECK: }
@ -890,15 +890,15 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
// CHECK-LABEL: @torch.aten.max.dim$basic( // CHECK-LABEL: @torch.aten.max.dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>) // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> // CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> // CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true // CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2 // CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> // CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> // CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> // CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
@ -1378,16 +1378,16 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
// CHECK-LABEL: func.func @torch.aten.min.dim$basic( // CHECK-LABEL: func.func @torch.aten.min.dim$basic(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> // CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> // CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true // CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 // CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> // CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> // CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> // CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> // CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32> // CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
// CHECK: } // CHECK: }
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
@ -1859,3 +1859,142 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0: !torch.vtensor<[?,?],si32> return %0: !torch.vtensor<[?,?],si32>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.diagonal$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int -2
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array<i64: 0, 0, 2, 0>} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32>
// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32>
// CHECK: }
func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> {
%dim1 = torch.constant.int 1
%dim2 = torch.constant.int 0
%offset = torch.constant.int -2
%0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
return %0 : !torch.vtensor<[5,6,2],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.index_select(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
// CHECK: }
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
%int2 = torch.constant.int 2
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
return %0 : !torch.vtensor<[4,5,2],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.fill.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
// CHECK: }
func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
return %0 : !torch.vtensor<[1,12,128,128],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
// CHECK: }
func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
%0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
return %0 : !torch.vtensor<[1,12,128,128],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.flip(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
// CHECK: }
func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
return %1 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.round(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
// CHECK: }
func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}

View File

@ -4,17 +4,17 @@
// CHECK-LABEL: func.func @scan_1d_inclusive( // CHECK-LABEL: func.func @scan_1d_inclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> // CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> // CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32> // CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>) // CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) { // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32): // CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32 // CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32 // CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
// CHECK: } // CHECK: }
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32> // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true) %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
@ -32,8 +32,10 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> // CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32> // CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> // CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32> // CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32> // CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
// CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>) // CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) { // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
@ -41,8 +43,6 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32 // CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32 // CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
// CHECK: } // CHECK: }
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32> // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false) %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
@ -62,14 +62,14 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> // CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> // CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32 // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
// CHECK: } // CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32> // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func.func @scatter_update_scalar_1D( func.func @scatter_update_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>, %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
@ -90,7 +90,8 @@ func.func @scatter_update_scalar_1D(
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> // CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> // CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
@ -99,7 +100,6 @@ func.func @scatter_update_scalar_1D(
// CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32 // CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
// CHECK: tm_tensor.yield %[[ADD]] : i32 // CHECK: tm_tensor.yield %[[ADD]] : i32
// CHECK: } // CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32> // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func.func @scatter_add_scalar_1D( func.func @scatter_add_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>, %original: tensor<8xi32>, %indices: tensor<3x1xi32>,

View File

@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
// expected-error @+1 {{op requires non-empty shapeSymbols}} // expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32> return %arg0 : !torch.vtensor<[?],f32>
} }
// ----- // -----
// Verifier should not fail here since the op does not require shapeSymbols.
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}
// -----
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}} // expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}

View File

@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32> %slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
return %slice : !torch.vtensor<[2],si32> return %slice : !torch.vtensor<[2],si32>
} }
// -----
// CHECK-LABEL: @view_as_flatten_static
func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
%int1024 = torch.constant.int 1024
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
return %3 : !torch.vtensor<[?,?,1024],f32>
}
// -----
// CHECK-LABEL: @view_as_unflatten_static
func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
// CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
%int16 = torch.constant.int 16
%int64 = torch.constant.int 64
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
return %3 : !torch.vtensor<[?,?,16,64],f32>
}
// -----
// CHECK-LABEL: @view_as_flatten_dynamic
func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
return %3 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: @unsqueeze_squeeze_combo
func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
// CHECK: %int0 = torch.constant.int 0
// CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
// CHECK: return %0 : !torch.int
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
%4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
%7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
return %14 : !torch.int
}

View File

@ -34,5 +34,5 @@ def test_enable_ir_printing():
) )
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) # CHECK: // -----// IR Dump After Inliner (inline)
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} { # CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {