mirror of https://github.com/llvm/torch-mlir
Merge branch 'main' into lower_torch_aten_gcd_to_linalg_and_scf
commit
ee7f6ee9fd
|
@ -50,7 +50,7 @@ TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}"
|
||||||
# Location to store Release wheels
|
# Location to store Release wheels
|
||||||
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
|
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
|
||||||
# What "packages to build"
|
# What "packages to build"
|
||||||
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}"
|
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-ext}"
|
||||||
# Use pre-built Pytorch
|
# Use pre-built Pytorch
|
||||||
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
|
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
|
||||||
# Skip running tests if you want quick iteration
|
# Skip running tests if you want quick iteration
|
||||||
|
@ -83,12 +83,12 @@ function run_on_host() {
|
||||||
fi
|
fi
|
||||||
mkdir -p "${TM_OUTPUT_DIR}"
|
mkdir -p "${TM_OUTPUT_DIR}"
|
||||||
case "$package" in
|
case "$package" in
|
||||||
torch-mlir)
|
torch-mlir-ext)
|
||||||
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
||||||
export USERID=0
|
export USERID=0
|
||||||
export GROUPID=0
|
export GROUPID=0
|
||||||
;;
|
;;
|
||||||
torch-mlir-core)
|
torch-mlir)
|
||||||
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
||||||
export USERID=0
|
export USERID=0
|
||||||
export GROUPID=0
|
export GROUPID=0
|
||||||
|
@ -158,22 +158,22 @@ function run_in_docker() {
|
||||||
export PATH=$python_dir/bin:$orig_path
|
export PATH=$python_dir/bin:$orig_path
|
||||||
echo ":::: Python version $(python3 --version)"
|
echo ":::: Python version $(python3 --version)"
|
||||||
case "$package" in
|
case "$package" in
|
||||||
torch-mlir)
|
torch-mlir-ext)
|
||||||
clean_wheels torch_mlir "$python_version"
|
clean_wheels torch_mlir_ext "$python_version"
|
||||||
build_torch_mlir "$TM_TORCH_VERSION"
|
build_torch_mlir_ext "$TM_TORCH_VERSION"
|
||||||
|
|
||||||
# Disable audit wheel until we can fix ODR torch issues. See
|
# Disable audit wheel until we can fix ODR torch issues. See
|
||||||
# https://github.com/llvm/torch-mlir/issues/1709
|
# https://github.com/llvm/torch-mlir/issues/1709
|
||||||
#
|
#
|
||||||
#run_audit_wheel torch_mlir "$python_version"
|
#run_audit_wheel torch_mlir_ext "$python_version"
|
||||||
|
|
||||||
clean_build torch_mlir "$python_version"
|
clean_build torch_mlir_ext "$python_version"
|
||||||
;;
|
;;
|
||||||
torch-mlir-core)
|
torch-mlir)
|
||||||
clean_wheels torch_mlir_core "$python_version"
|
clean_wheels torch_mlir "$python_version"
|
||||||
build_torch_mlir_core
|
build_torch_mlir
|
||||||
run_audit_wheel torch_mlir_core "$python_version"
|
run_audit_wheel torch_mlir "$python_version"
|
||||||
clean_build torch_mlir_core "$python_version"
|
clean_build torch_mlir "$python_version"
|
||||||
;;
|
;;
|
||||||
out-of-tree)
|
out-of-tree)
|
||||||
setup_venv "$python_version" "$TM_TORCH_VERSION"
|
setup_venv "$python_version" "$TM_TORCH_VERSION"
|
||||||
|
@ -431,7 +431,7 @@ function clean_build() {
|
||||||
rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch
|
rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_torch_mlir() {
|
function build_torch_mlir_ext() {
|
||||||
# Disable LTC build for releases
|
# Disable LTC build for releases
|
||||||
export TORCH_MLIR_ENABLE_LTC=0
|
export TORCH_MLIR_ENABLE_LTC=0
|
||||||
local torch_version="$1"
|
local torch_version="$1"
|
||||||
|
@ -470,7 +470,9 @@ function run_audit_wheel() {
|
||||||
rm "$generic_wheel"
|
rm "$generic_wheel"
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_torch_mlir_core() {
|
function build_torch_mlir() {
|
||||||
|
# Disable LTC build for releases
|
||||||
|
export TORCH_MLIR_ENABLE_LTC=0
|
||||||
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
||||||
CMAKE_GENERATOR=Ninja \
|
CMAKE_GENERATOR=Ninja \
|
||||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||||
|
|
|
@ -56,16 +56,16 @@ function run() {
|
||||||
export PATH=$python_dir/bin:$orig_path
|
export PATH=$python_dir/bin:$orig_path
|
||||||
echo ":::: Python version $(python3 --version)"
|
echo ":::: Python version $(python3 --version)"
|
||||||
case "$package" in
|
case "$package" in
|
||||||
|
torch-mlir-ext)
|
||||||
|
clean_wheels torch_mlir_ext "$python_version"
|
||||||
|
build_torch_mlir_ext torch_mlir_ext "$python_version"
|
||||||
|
run_audit_wheel torch_mlir_ext "$python_version"
|
||||||
|
;;
|
||||||
torch-mlir)
|
torch-mlir)
|
||||||
clean_wheels torch_mlir "$python_version"
|
clean_wheels torch_mlir "$python_version"
|
||||||
build_torch_mlir torch_mlir "$python_version"
|
build_torch_mlir torch_mlir "$python_version"
|
||||||
run_audit_wheel torch_mlir "$python_version"
|
run_audit_wheel torch_mlir "$python_version"
|
||||||
;;
|
;;
|
||||||
torch-mlir-core)
|
|
||||||
clean_wheels torch_mlir_core "$python_version"
|
|
||||||
build_torch_mlir_core torch_mlir_core "$python_version"
|
|
||||||
run_audit_wheel torch_mlir_core "$python_version"
|
|
||||||
;;
|
|
||||||
*)
|
*)
|
||||||
echo "Unrecognized package '$package'"
|
echo "Unrecognized package '$package'"
|
||||||
exit 1
|
exit 1
|
||||||
|
@ -75,7 +75,7 @@ function run() {
|
||||||
done
|
done
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_torch_mlir() {
|
function build_torch_mlir_ext() {
|
||||||
local wheel_basename="$1"
|
local wheel_basename="$1"
|
||||||
local python_version="$2"
|
local python_version="$2"
|
||||||
rm -rf "$output_dir"/build_venv
|
rm -rf "$output_dir"/build_venv
|
||||||
|
@ -93,7 +93,7 @@ function build_torch_mlir() {
|
||||||
rm -rf "$output_dir"/build_venv
|
rm -rf "$output_dir"/build_venv
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_torch_mlir_core() {
|
function build_torch_mlir() {
|
||||||
local wheel_basename="$1"
|
local wheel_basename="$1"
|
||||||
local python_version="$2"
|
local python_version="$2"
|
||||||
rm -rf "$output_dir"/build_venv
|
rm -rf "$output_dir"/build_venv
|
||||||
|
|
|
@ -14,7 +14,7 @@ While this is running, you can already setup the Python venv and dependencies in
|
||||||
## Setup your Python VirtualEnvironment and Dependencies
|
## Setup your Python VirtualEnvironment and Dependencies
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python -m venv mlir_venv
|
python3 -m venv mlir_venv
|
||||||
source mlir_venv/bin/activate
|
source mlir_venv/bin/activate
|
||||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit d418a03e01e6a31b51b0c9dd42ba46da6c47f89d
|
Subproject commit e813750354bbc08551cf23ff559a54b4a9ea1f29
|
|
@ -1 +1 @@
|
||||||
Subproject commit c28d55e91b4a5daaff18a33ce7e9bbd0f171256a
|
Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739
|
|
@ -34,6 +34,7 @@ struct OpBinder {
|
||||||
Location getLoc() { return op->getLoc(); }
|
Location getLoc() { return op->getLoc(); }
|
||||||
|
|
||||||
int getNumOperands() { return op->getNumOperands(); }
|
int getNumOperands() { return op->getNumOperands(); }
|
||||||
|
int getNumResults() { return op->getNumResults(); }
|
||||||
|
|
||||||
// Operand matches of different arities.
|
// Operand matches of different arities.
|
||||||
ParseResult tensorOperand(Value &value0) {
|
ParseResult tensorOperand(Value &value0) {
|
||||||
|
@ -338,6 +339,31 @@ struct OpBinder {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
|
||||||
|
StringRef nameSuffix,
|
||||||
|
ArrayRef<float> defaults) {
|
||||||
|
SmallString<64> name("torch.onnx.");
|
||||||
|
name.append(nameSuffix);
|
||||||
|
auto attr = op->getAttr(name);
|
||||||
|
if (!attr) {
|
||||||
|
values.append(defaults.begin(), defaults.end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||||
|
for (auto element : arrayAttr) {
|
||||||
|
auto floatAttr = dyn_cast<FloatAttr>(element);
|
||||||
|
if (!floatAttr)
|
||||||
|
return failure();
|
||||||
|
FloatType t = cast<FloatType>(floatAttr.getType());
|
||||||
|
if (t.getWidth() != 32)
|
||||||
|
return failure();
|
||||||
|
values.push_back(floatAttr.getValue().convertToFloat());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
|
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
|
||||||
StringRef nameSuffix) {
|
StringRef nameSuffix) {
|
||||||
SmallString<64> name("torch.onnx.");
|
SmallString<64> name("torch.onnx.");
|
||||||
|
|
|
@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
|
||||||
Location loc, SmallVector<int64_t> dimensions,
|
Location loc, SmallVector<int64_t> dimensions,
|
||||||
Value input, Value &result);
|
Value input, Value &result);
|
||||||
|
|
||||||
|
// Flips an input tensor based on the values of axis list.
|
||||||
|
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
|
||||||
|
SmallVector<int64_t> axis);
|
||||||
|
|
||||||
} // namespace torch_to_linalg
|
} // namespace torch_to_linalg
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
||||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
|
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
|
||||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
||||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||||
|
|
|
@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||||
|
|
||||||
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||||
Type elemTy);
|
Type elemTy);
|
||||||
|
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||||
|
Type elemTy);
|
||||||
|
|
||||||
Value castIntToIndex(OpBuilder &b, Location loc, Value v);
|
Value castIntToIndex(OpBuilder &b, Location loc, Value v);
|
||||||
|
|
||||||
|
|
|
@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenComplexOp : Torch_Op<"aten.complex", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$real,
|
||||||
|
AnyTorchTensorType:$imag
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenComplexOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
|
@ -7078,6 +7102,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchListOfTorchIntType:$kernel_size,
|
||||||
|
AnyTorchListOfTorchIntType:$stride,
|
||||||
|
AnyTorchListOfTorchIntType:$padding,
|
||||||
|
AnyTorchListOfTorchIntType:$dilation,
|
||||||
|
Torch_BoolType:$ceil_mode
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result0,
|
||||||
|
AnyTorchOptionalTensorType:$result1
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 2);
|
||||||
|
}
|
||||||
|
void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 2);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -9195,6 +9248,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$target,
|
||||||
|
AnyTorchOptionalTensorType:$weight,
|
||||||
|
AnyTorchOptionalTensorType:$pos_weight,
|
||||||
|
Torch_IntType:$reduction
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||||
|
}
|
||||||
|
void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 5, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
|
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -13637,6 +13717,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dimension,
|
||||||
|
Torch_IntType:$size,
|
||||||
|
Torch_IntType:$step
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenUnfoldOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto shapeSizes = shapeType.getSizes();
|
auto shapeSizes = shapeType.getSizes();
|
||||||
int64_t dataRank = dataType.getSizes().size();
|
ArrayRef<int64_t> dataShape = dataType.getSizes();
|
||||||
|
int64_t dataRank = dataShape.size();
|
||||||
int64_t shapeRank = shapeSizes.size();
|
int64_t shapeRank = shapeSizes.size();
|
||||||
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
|
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
// we are using torch implementation Torch::AtenBroadcastToOp which
|
// we are using torch implementation Torch::AtenBroadcastToOp which
|
||||||
// takes list of int
|
// takes list of int
|
||||||
for (int i = 0; i < shapeSizes[0]; i++) {
|
for (int i = 0; i < shapeSizes[0]; i++) {
|
||||||
|
// extract dim from shape
|
||||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getType<Torch::IntType>(),
|
loc, rewriter.getType<Torch::IntType>(),
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
loc, selectResultType, shape, zero, selectIndex);
|
loc, selectResultType, shape, zero, selectIndex);
|
||||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
Value selectDim = rewriter.create<Torch::AtenItemOp>(
|
||||||
loc, rewriter.getType<Torch::IntType>(), extract);
|
loc, rewriter.getType<Torch::IntType>(), extract);
|
||||||
|
// compute dim to pass to broadcast op. For non-broadcastable dims,
|
||||||
if (i + rankDifference >= 0) {
|
// pass -1
|
||||||
|
Value dim;
|
||||||
|
if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
|
||||||
|
// 1. if dataShape[i + rankDiff] > 1, then this cannot be
|
||||||
|
// broadcasted
|
||||||
|
// 2. we will explicitly disallow broadcasting dynamic dims that are
|
||||||
|
// secretly 1.
|
||||||
|
dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
|
||||||
|
// Assert dataShape[i + rankDiff] >= selectDim. If both are
|
||||||
|
// constant, this should fold out.
|
||||||
Value iv =
|
Value iv =
|
||||||
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
|
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
|
||||||
auto sz = rewriter.create<Torch::AtenSizeIntOp>(
|
auto sz = rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
loc, rewriter.getType<Torch::IntType>(), data, iv);
|
loc, rewriter.getType<Torch::IntType>(), data, iv);
|
||||||
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
|
Value gtSelect =
|
||||||
|
rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
|
||||||
|
rewriter.create<Torch::RuntimeAssertOp>(
|
||||||
|
loc, gtSelect,
|
||||||
|
rewriter.getStringAttr(
|
||||||
|
"onnx.Expand input has a dim that is not statically 1; "
|
||||||
|
"expected this dim >= dim provided shape."));
|
||||||
|
} else {
|
||||||
|
// 1. excess selectDims get included in broadcast (shapeSizes[0] >
|
||||||
|
// dataRank)
|
||||||
|
// 2. selectDims which correspond to dataShape == 1 get included in
|
||||||
|
// broadcast
|
||||||
|
dim = selectDim;
|
||||||
}
|
}
|
||||||
|
|
||||||
dimList.push_back(dim);
|
dimList.push_back(dim);
|
||||||
}
|
}
|
||||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
|
|
@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
// TODO: Implement max and min cases
|
// TODO: Implement max and min cases
|
||||||
if (reduction == "mul") {
|
if (reduction == "mul") {
|
||||||
reduction = "multiply";
|
reduction = "prod";
|
||||||
} else if (reduction == "max" || reduction == "min") {
|
} else if (reduction == "max" || reduction == "min") {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "max/min reduction unsupported for scatter elements");
|
binder.op, "max/min reduction unsupported for scatter elements");
|
||||||
|
} else if (reduction == "add") {
|
||||||
|
reduction = "sum";
|
||||||
}
|
}
|
||||||
|
|
||||||
Value cstStrReduction =
|
Value cstStrReduction =
|
||||||
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
|
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
|
||||||
|
Value cstTrue =
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
|
||||||
binder.op, resultType, data, constAxis, indices, updates,
|
binder.op, resultType, data, constAxis, indices, updates,
|
||||||
cstStrReduction);
|
cstStrReduction, cstTrue);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
|
@ -1662,10 +1665,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
auto shapeType = Torch::ValueTensorType::get(
|
auto shapeType = Torch::ValueTensorType::get(
|
||||||
binder.op->getContext(), SmallVector<int64_t>{inputRank},
|
binder.op->getContext(), SmallVector<int64_t>{inputRank},
|
||||||
resultType.getOptionalDtype());
|
resultType.getOptionalDtype());
|
||||||
|
|
||||||
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
|
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
|
||||||
binder.getLoc(), shapeType, operand);
|
binder.getLoc(), shapeType, operand);
|
||||||
|
|
||||||
|
if (inputRank == 0) {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
|
||||||
|
binder.op, resultType, shape);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
if (start == 0 && end == -1) {
|
if (start == 0 && end == -1) {
|
||||||
rewriter.replaceOp(binder.op, shape);
|
rewriter.replaceOp(binder.op, shape);
|
||||||
return success();
|
return success();
|
||||||
|
@ -1673,18 +1681,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
Value sv = rewriter.create<Torch::ConstantIntOp>(
|
Value sv = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(start));
|
binder.getLoc(), rewriter.getI64IntegerAttr(start));
|
||||||
|
|
||||||
Value ev = rewriter.create<Torch::ConstantIntOp>(
|
Value ev = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(end));
|
binder.getLoc(), rewriter.getI64IntegerAttr(end));
|
||||||
|
|
||||||
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
|
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
|
||||||
|
|
||||||
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
|
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
|
||||||
|
|
||||||
shape = rewriter.create<Torch::AtenSliceTensorOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
|
||||||
binder.getLoc(), resultType, shape, dim, sv, ev, step);
|
binder.op, resultType, shape, dim, sv, ev, step);
|
||||||
|
|
||||||
rewriter.replaceOp(binder.op, shape);
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -4339,6 +4342,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
llvm::SmallVector<int64_t> ngram_counts;
|
llvm::SmallVector<int64_t> ngram_counts;
|
||||||
llvm::SmallVector<int64_t> ngram_indexes;
|
llvm::SmallVector<int64_t> ngram_indexes;
|
||||||
llvm::SmallVector<int64_t> pool_int64s;
|
llvm::SmallVector<int64_t> pool_int64s;
|
||||||
|
llvm::SmallVector<float> weights;
|
||||||
std::string mode;
|
std::string mode;
|
||||||
int64_t min_gram_length;
|
int64_t min_gram_length;
|
||||||
int64_t max_gram_length;
|
int64_t max_gram_length;
|
||||||
|
@ -4356,9 +4360,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.tensorOperand(input) || binder.tensorResultType(resultType))
|
binder.tensorOperand(input) || binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (mode != "TF")
|
llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
|
||||||
"TF mode supported only");
|
return failure();
|
||||||
|
|
||||||
if (pool_int64s.size() == 0)
|
if (pool_int64s.size() == 0)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "pool_int64s empty, only integers supported");
|
binder.op, "pool_int64s empty, only integers supported");
|
||||||
|
@ -4584,9 +4589,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.getLoc(), loopConditionTrue, ValueRange({count}));
|
binder.getLoc(), loopConditionTrue, ValueRange({count}));
|
||||||
}
|
}
|
||||||
count = skipLoop.getResult(0);
|
count = skipLoop.getResult(0);
|
||||||
// insert count "tf" into output
|
|
||||||
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
|
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
|
||||||
binder.getLoc(), count);
|
binder.getLoc(), count);
|
||||||
|
if (mode == "IDF" || mode == "TFIDF") {
|
||||||
|
// both IDF and TFIDF modes use weights
|
||||||
|
float weight = weights[ngram_i];
|
||||||
|
Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(), rewriter.getF64FloatAttr(weight));
|
||||||
|
|
||||||
|
// TFIDF
|
||||||
|
Value multiplier = countFloat;
|
||||||
|
if (mode == "IDF") {
|
||||||
|
// All the counts larger than 1 would be truncated to 1
|
||||||
|
// and the i-th element in weights would be used to scale
|
||||||
|
// (by multiplication) the count of the i-th n-gram in pool.
|
||||||
|
|
||||||
|
Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
|
||||||
|
binder.getLoc(), count);
|
||||||
|
// compare intCount > 0
|
||||||
|
Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
|
||||||
|
binder.getLoc(), intCount, zero);
|
||||||
|
gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
|
||||||
|
binder.getLoc(), gtZeroCount);
|
||||||
|
Value gtZeroCountFloat =
|
||||||
|
rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
|
||||||
|
gtZeroCount);
|
||||||
|
multiplier = gtZeroCountFloat;
|
||||||
|
}
|
||||||
|
countFloat = rewriter.create<Torch::AtenMulFloatOp>(
|
||||||
|
binder.getLoc(), multiplier, constWeight);
|
||||||
|
}
|
||||||
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
binder.getLoc(),
|
||||||
rewriter.getType<Torch::ListType>(
|
rewriter.getType<Torch::ListType>(
|
||||||
|
|
|
@ -661,8 +661,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
std::string direction;
|
std::string direction;
|
||||||
|
|
||||||
ValueTensorType yTy, Y_hType, Y_cType;
|
ValueTensorType yTy, Y_hType, Y_cType;
|
||||||
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
|
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
|
||||||
binder.tensorResultTypeAtIndex(Y_hType, 1) ||
|
binder.tensorResultTypeAtIndex(Y_hType, 1) &&
|
||||||
binder.tensorResultTypeAtIndex(Y_cType, 2)) {
|
binder.tensorResultTypeAtIndex(Y_cType, 2)) {
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"At least one outputs must be present");
|
"At least one outputs must be present");
|
||||||
|
@ -686,51 +686,110 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
|
|
||||||
auto xTy = cast<ValueTensorType>(X.getType());
|
auto xTy = cast<ValueTensorType>(X.getType());
|
||||||
auto wTy = cast<ValueTensorType>(W.getType());
|
auto wTy = cast<ValueTensorType>(W.getType());
|
||||||
Value B;
|
|
||||||
if (binder.tensorOperandAtIndex(B, 3)) {
|
// TODO: add defaults for activation_alpha acticvation_beta attributes
|
||||||
B = b.create<AtenZerosOp>(W.getType(), W);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<std::string> activationsList;
|
llvm::SmallVector<std::string> activationsList;
|
||||||
if (binder.stringArrayAttr(activationsList, "activations"))
|
if (binder.stringArrayAttr(activationsList, "activations"))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Missing required attribute; activations");
|
binder.op, "Missing required attribute; activations");
|
||||||
|
|
||||||
LstmActivations activations;
|
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
|
||||||
activations.f = "Sigmoid";
|
direction != "forward" && direction != "bidirectional")
|
||||||
activations.g = "Tanh";
|
|
||||||
activations.h = "Tanh";
|
|
||||||
if (activationsList.size() == 3) {
|
|
||||||
activations.f = activationsList[0];
|
|
||||||
activations.g = activationsList[1];
|
|
||||||
activations.h = activationsList[2];
|
|
||||||
} else if (activationsList.size() != 0) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "activations must be empty have 3 elements, but " +
|
binder.op, "Unsupported direction attribute value. "
|
||||||
|
"Only 'forward' / 'bidrectional' are supported but '" +
|
||||||
|
direction + "' is provided.");
|
||||||
|
int64_t num_directions = 1 + (direction == "bidirectional");
|
||||||
|
bool isBidirectional = direction == "bidirectional";
|
||||||
|
// There can be backward activations too
|
||||||
|
// if backward -> look for 6 atcivations (what happens when only three?)
|
||||||
|
|
||||||
|
int64_t num_activations = activationsList.size();
|
||||||
|
if (num_activations != 0 && num_activations != 3 && num_activations != 6) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "activations must either be empty (default), have 3 elements"
|
||||||
|
" (forward) or, have 6 elements (bidirectional), but " +
|
||||||
std::to_string(activationsList.size()) +
|
std::to_string(activationsList.size()) +
|
||||||
" are provided.");
|
" are provided.");
|
||||||
}
|
}
|
||||||
|
// TODO : Add checks, defaults and fails for inputs - sequence_lens, P and
|
||||||
|
// attrs- clip, input_forget, layout
|
||||||
|
|
||||||
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
|
Value B;
|
||||||
direction != "forward")
|
if (binder.tensorOperandAtIndex(B, 3)) {
|
||||||
|
Value none = b.create<ConstantNoneOp>();
|
||||||
|
Value cstHiddenx8 = b.create<ConstantIntOp>(
|
||||||
|
b.getType<IntType>(), b.getI64IntegerAttr(8 * hidden_size));
|
||||||
|
Value cstNumDir = b.create<ConstantIntOp>(
|
||||||
|
b.getType<IntType>(), b.getI64IntegerAttr(num_directions));
|
||||||
|
auto BType = b.getType<ValueTensorType>(
|
||||||
|
llvm::SmallVector<int64_t>{num_directions, 8 * hidden_size},
|
||||||
|
cast<ValueTensorType>(W.getType()).getDtype());
|
||||||
|
Value zerosShapeList = b.create<PrimListConstructOp>(
|
||||||
|
b.getType<ListType>(b.getType<IntType>()),
|
||||||
|
SmallVector<Value>{cstNumDir, cstHiddenx8});
|
||||||
|
B = b.create<AtenZerosOp>(BType, zerosShapeList, none, none, none, none);
|
||||||
|
}
|
||||||
|
|
||||||
|
LstmActivations activations, activationsRev;
|
||||||
|
// Default case (both forward and reverse)
|
||||||
|
activations.f = "Sigmoid";
|
||||||
|
activations.g = "Tanh";
|
||||||
|
activations.h = "Tanh";
|
||||||
|
activationsRev.f = "Sigmoid";
|
||||||
|
activationsRev.g = "Tanh";
|
||||||
|
activationsRev.h = "Tanh";
|
||||||
|
|
||||||
|
// forward only (also to be added for bidirectional case)
|
||||||
|
if (num_activations >= 3) {
|
||||||
|
activations.f = activationsList[0];
|
||||||
|
activations.g = activationsList[1];
|
||||||
|
activations.h = activationsList[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
// bidirectional
|
||||||
|
if (num_activations == 6) {
|
||||||
|
activationsRev.f = activationsList[3];
|
||||||
|
activationsRev.g = activationsList[4];
|
||||||
|
activationsRev.h = activationsList[5];
|
||||||
|
}
|
||||||
|
|
||||||
|
float clip;
|
||||||
|
if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0)
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"Unsupported direction attribute value. "
|
"clip attribute not supported");
|
||||||
"Only 'forward' is supported but '" +
|
|
||||||
direction + "' is provided.");
|
int64_t input_forget;
|
||||||
int64_t num_directions = 1 + (direction == "bidirectional");
|
if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) &&
|
||||||
|
input_forget != 0)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "only input_forget = 0 supported. Got input_forgt = " +
|
||||||
|
std::to_string(input_forget));
|
||||||
|
|
||||||
|
int64_t layout;
|
||||||
|
if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
|
||||||
|
std::to_string(layout));
|
||||||
|
|
||||||
auto XShape = xTy.getSizes();
|
auto XShape = xTy.getSizes();
|
||||||
int64_t batch_size = XShape[1];
|
int64_t seq_len, batch_size;
|
||||||
|
if (layout == 0) {
|
||||||
|
seq_len = XShape[0];
|
||||||
|
batch_size = XShape[1];
|
||||||
|
} else {
|
||||||
|
seq_len = XShape[1];
|
||||||
|
batch_size = XShape[0];
|
||||||
|
}
|
||||||
|
|
||||||
int64_t input_size = XShape[2];
|
int64_t input_size = XShape[2];
|
||||||
if (num_directions != wTy.getSizes()[0])
|
if (num_directions != wTy.getSizes()[0])
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "num_directions (" + std::to_string(num_directions) +
|
binder.op, "num_directions (" + std::to_string(num_directions) +
|
||||||
") does not match the first dimension of wTy (" +
|
") does not match the first dimension of wTy (" +
|
||||||
std::to_string(wTy.getSizes()[0]) + ")");
|
std::to_string(wTy.getSizes()[0]) + ")");
|
||||||
if (num_directions != 1)
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
binder.op, "num_directions (" + std::to_string(num_directions) +
|
|
||||||
") is not equal to 1");
|
|
||||||
if (4 * hidden_size != wTy.getSizes()[1])
|
if (4 * hidden_size != wTy.getSizes()[1])
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
|
binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
|
||||||
|
@ -746,6 +805,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
Value R_forward = getDirection(b, 0, R);
|
Value R_forward = getDirection(b, 0, R);
|
||||||
Value B_forward = getDirection(b, 0, B);
|
Value B_forward = getDirection(b, 0, B);
|
||||||
|
|
||||||
|
Value W_reverse, R_reverse, B_reverse;
|
||||||
|
if (isBidirectional) {
|
||||||
|
W_reverse = getDirection(b, 1, W);
|
||||||
|
R_reverse = getDirection(b, 1, R);
|
||||||
|
B_reverse = getDirection(b, 1, B);
|
||||||
|
}
|
||||||
|
|
||||||
auto hTy = b.getType<ValueTensorType>(
|
auto hTy = b.getType<ValueTensorType>(
|
||||||
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
||||||
xTy.getDtype());
|
xTy.getDtype());
|
||||||
|
@ -770,29 +836,44 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
|
|
||||||
Value initial_h;
|
Value initial_h;
|
||||||
if (binder.tensorOperandAtIndex(initial_h, 5)) {
|
if (binder.tensorOperandAtIndex(initial_h, 5)) {
|
||||||
|
// default created for layout 0
|
||||||
initial_h =
|
initial_h =
|
||||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||||
|
} else {
|
||||||
|
if (layout == 1)
|
||||||
|
initial_h = StaticTranspose(b, initial_h, 0, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value initial_c;
|
Value initial_c;
|
||||||
if (binder.tensorOperandAtIndex(initial_c, 6)) {
|
if (binder.tensorOperandAtIndex(initial_c, 6)) {
|
||||||
|
// default created for layout 0
|
||||||
initial_c =
|
initial_c =
|
||||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||||
|
} else {
|
||||||
|
if (layout == 1)
|
||||||
|
initial_c = StaticTranspose(b, initial_c, 0, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convert X from layout 1 to layout 0
|
||||||
|
if (layout == 1)
|
||||||
|
X = StaticTranspose(b, X, 0, 1);
|
||||||
|
|
||||||
|
// X, initial_h, initial_c are now in layout 0
|
||||||
|
|
||||||
Value initial_h_forward = getDirection(b, 0, initial_h);
|
Value initial_h_forward = getDirection(b, 0, initial_h);
|
||||||
Value initial_c_forward = getDirection(b, 0, initial_c);
|
Value initial_c_forward = getDirection(b, 0, initial_c);
|
||||||
|
|
||||||
if (num_directions != 1) {
|
Value initial_h_reverse, initial_c_reverse;
|
||||||
return rewriter.notifyMatchFailure(
|
if (isBidirectional) {
|
||||||
binder.op, "Unsupported num_directions. Only 1 is supported but " +
|
initial_h_reverse = getDirection(b, 1, initial_h);
|
||||||
std::to_string(num_directions) + " is provided.");
|
initial_c_reverse = getDirection(b, 1, initial_c);
|
||||||
// TODO: support bidirectional LSTM by doing both directions and replacing
|
|
||||||
// Unsqueeze with Stack
|
|
||||||
}
|
}
|
||||||
// Everything hereon is for the forward direction, with the direction
|
|
||||||
// dimention squeezed out.
|
|
||||||
|
|
||||||
LstmWeights weights; // weights and biases
|
// Everything hereon is for the forward direction (unless in bidirectional if
|
||||||
|
// block), with the direction dimention squeezed out and all inputs in layout
|
||||||
|
// 0 format
|
||||||
|
|
||||||
|
LstmWeights weights, weightsRev; // weights and biases
|
||||||
|
|
||||||
auto intConst = [&](int64_t val) {
|
auto intConst = [&](int64_t val) {
|
||||||
return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val));
|
return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val));
|
||||||
|
@ -804,6 +885,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
|
Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
|
||||||
auto biasType = b.getType<ValueTensorType>(
|
auto biasType = b.getType<ValueTensorType>(
|
||||||
llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
|
llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
|
||||||
|
// forward
|
||||||
Value Wb = b.create<AtenSliceTensorOp>(biasType,
|
Value Wb = b.create<AtenSliceTensorOp>(biasType,
|
||||||
/*input=*/B_forward,
|
/*input=*/B_forward,
|
||||||
/*dim=*/cstZero,
|
/*dim=*/cstZero,
|
||||||
|
@ -816,6 +898,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
/*start=*/recurrentWeightsStartIdx,
|
/*start=*/recurrentWeightsStartIdx,
|
||||||
/*end=*/recurrentWeightsEndIdx,
|
/*end=*/recurrentWeightsEndIdx,
|
||||||
/*step=*/cstOne);
|
/*step=*/cstOne);
|
||||||
|
Value Wb_reverse, Rb_reverse;
|
||||||
|
if (isBidirectional) {
|
||||||
|
// reverse
|
||||||
|
Wb_reverse = b.create<AtenSliceTensorOp>(biasType,
|
||||||
|
/*input=*/B_reverse,
|
||||||
|
/*dim=*/cstZero,
|
||||||
|
/*start=*/cstZero,
|
||||||
|
/*end=*/inputWeightsEndIdx,
|
||||||
|
/*step=*/cstOne);
|
||||||
|
Rb_reverse = b.create<AtenSliceTensorOp>(biasType,
|
||||||
|
/*input=*/B_reverse,
|
||||||
|
/*dim=*/cstZero,
|
||||||
|
/*start=*/recurrentWeightsStartIdx,
|
||||||
|
/*end=*/recurrentWeightsEndIdx,
|
||||||
|
/*step=*/cstOne);
|
||||||
|
}
|
||||||
|
|
||||||
// gate splitting
|
// gate splitting
|
||||||
auto gateBiasType = b.getType<ValueTensorType>(
|
auto gateBiasType = b.getType<ValueTensorType>(
|
||||||
|
@ -833,61 +931,164 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
||||||
Value forgetGateWeightsEndIdx = intConst(3 * hidden_size);
|
Value forgetGateWeightsEndIdx = intConst(3 * hidden_size);
|
||||||
Value cellGateWeightsEndIdx = intConst(4 * hidden_size);
|
Value cellGateWeightsEndIdx = intConst(4 * hidden_size);
|
||||||
|
|
||||||
auto sliceIOFC = [&](std::function<Value(Value, Value)> slicerFunction) {
|
auto sliceIOFC = [&](std::function<Value(Value, Value, Value)> slicerFunction,
|
||||||
|
Value WoB) {
|
||||||
// slice into 4 components and return tuple
|
// slice into 4 components and return tuple
|
||||||
return std::make_tuple(
|
return std::make_tuple(
|
||||||
slicerFunction(cstZero, inputGateWeightsEndIdx),
|
slicerFunction(cstZero, inputGateWeightsEndIdx, WoB),
|
||||||
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx),
|
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB),
|
||||||
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx),
|
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB),
|
||||||
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx));
|
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB));
|
||||||
};
|
};
|
||||||
|
|
||||||
auto sliceGateBias = [&](Value startIdx, Value endIdx) {
|
auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||||
return b.create<AtenSliceTensorOp>(gateBiasType, Wb, cstZero, startIdx,
|
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
|
||||||
endIdx, cstOne);
|
endIdx, cstOne);
|
||||||
};
|
};
|
||||||
std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) =
|
std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) =
|
||||||
sliceIOFC(sliceGateBias);
|
sliceIOFC(sliceGateBias, Wb);
|
||||||
|
|
||||||
auto sliceGateBiasR = [&](Value startIdx, Value endIdx) {
|
if (isBidirectional)
|
||||||
return b.create<AtenSliceTensorOp>(gateBiasType, Rb, cstZero, startIdx,
|
std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f,
|
||||||
|
weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse);
|
||||||
|
|
||||||
|
auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||||
|
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
|
||||||
endIdx, cstOne);
|
endIdx, cstOne);
|
||||||
};
|
};
|
||||||
std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) =
|
std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) =
|
||||||
sliceIOFC(sliceGateBiasR);
|
sliceIOFC(sliceGateBiasR, Rb);
|
||||||
|
|
||||||
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) {
|
if (isBidirectional)
|
||||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, W_forward, cstZero,
|
std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f,
|
||||||
|
weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse);
|
||||||
|
|
||||||
|
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||||
|
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, WoB, cstZero,
|
||||||
startIdx, endIdx, cstOne);
|
startIdx, endIdx, cstOne);
|
||||||
};
|
};
|
||||||
std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) =
|
std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) =
|
||||||
sliceIOFC(sliceGateWeightsIH);
|
sliceIOFC(sliceGateWeightsIH, W_forward);
|
||||||
|
|
||||||
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) {
|
if (isBidirectional)
|
||||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, R_forward, cstZero,
|
std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) =
|
||||||
|
sliceIOFC(sliceGateWeightsIH, W_reverse);
|
||||||
|
|
||||||
|
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||||
|
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, WoB, cstZero,
|
||||||
startIdx, endIdx, cstOne);
|
startIdx, endIdx, cstOne);
|
||||||
};
|
};
|
||||||
|
|
||||||
std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) =
|
std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) =
|
||||||
sliceIOFC(sliceGateWeightsHH);
|
sliceIOFC(sliceGateWeightsHH, R_forward);
|
||||||
|
|
||||||
|
if (isBidirectional)
|
||||||
|
std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
|
||||||
|
sliceIOFC(sliceGateWeightsHH, R_reverse);
|
||||||
|
|
||||||
LstmLayerOutput lstmLayerOutput = lstm_layer(
|
LstmLayerOutput lstmLayerOutput = lstm_layer(
|
||||||
b, X, initial_h_forward, initial_c_forward, weights, activations);
|
b, X, initial_h_forward, initial_c_forward, weights, activations);
|
||||||
|
|
||||||
auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
|
Value Y_h_result, Y_c_result, Y_result;
|
||||||
|
|
||||||
|
// if forward (unidirectional) unsqueeze and output
|
||||||
|
auto YallDtype =
|
||||||
|
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
|
||||||
|
auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
|
||||||
|
llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
|
||||||
|
auto Y_uni_type = b.getType<ValueTensorType>(
|
||||||
|
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
|
||||||
|
YallDtype);
|
||||||
|
auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
|
||||||
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
||||||
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype());
|
YallDtype);
|
||||||
Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(
|
auto Y_res_type = b.getType<ValueTensorType>(
|
||||||
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero);
|
llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
|
||||||
Value Y_c_unsqueezed = b.create<AtenUnsqueezeOp>(
|
hidden_size},
|
||||||
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero);
|
YallDtype);
|
||||||
|
|
||||||
|
Value Y_h_forward =
|
||||||
|
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
|
||||||
|
|
||||||
|
Value Y_c_forward =
|
||||||
|
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero);
|
||||||
|
|
||||||
// unsqueeze num_directions dim1 of Y
|
// unsqueeze num_directions dim1 of Y
|
||||||
// to create the onnx.LSTM output shape [seq_length, num_directions,
|
// to create the onnx.LSTM output shape [seq_length, num_directions,
|
||||||
// batch_size, hidden_size]
|
// batch_size, hidden_size]
|
||||||
Value Y_unsqueezed =
|
Value Y_forward =
|
||||||
b.create<AtenUnsqueezeOp>(yTy, lstmLayerOutput.Y, cstOne);
|
b.create<AtenUnsqueezeOp>(Y_uni_type, lstmLayerOutput.Y, cstOne);
|
||||||
|
|
||||||
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed,
|
Y_result = Y_forward;
|
||||||
Y_c_unsqueezed});
|
Y_h_result = Y_h_forward;
|
||||||
|
Y_c_result = Y_c_forward;
|
||||||
|
|
||||||
|
// add bidrectional reverse layer
|
||||||
|
// this is just flip X, lstm layer, flip results, stack
|
||||||
|
// flip X
|
||||||
|
Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped,
|
||||||
|
Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list;
|
||||||
|
LstmLayerOutput revLstmLayerOutput;
|
||||||
|
if (isBidirectional) {
|
||||||
|
dim0 = b.create<PrimListConstructOp>(b.getType<ListType>(intType),
|
||||||
|
SmallVector<Value>{cstZero});
|
||||||
|
X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
|
||||||
|
revLstmLayerOutput =
|
||||||
|
lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
|
||||||
|
weightsRev, activationsRev);
|
||||||
|
|
||||||
|
// unsqueeze Y_rev, Y_h_rev, Y_c_rev
|
||||||
|
Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
|
||||||
|
revLstmLayerOutput.Y_h, cstZero);
|
||||||
|
Y_c_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
|
||||||
|
revLstmLayerOutput.Y_c, cstZero);
|
||||||
|
Y_reverse_unflipped =
|
||||||
|
b.create<AtenUnsqueezeOp>(Y_uni_type, revLstmLayerOutput.Y, cstOne);
|
||||||
|
|
||||||
|
// flip Y_rev on dim 0 [seq_len]
|
||||||
|
Y_reverse = b.create<AtenFlipOp>(Y_uni_type, Y_reverse_unflipped, dim0);
|
||||||
|
|
||||||
|
// Concat forward and reverse results on dim 1
|
||||||
|
Y_output_list =
|
||||||
|
b.create<PrimListConstructOp>(b.getType<ListType>(Y_uni_type),
|
||||||
|
SmallVector<Value>{Y_forward, Y_reverse});
|
||||||
|
Y_result = b.create<AtenCatOp>(Y_res_type, Y_output_list, cstOne);
|
||||||
|
|
||||||
|
// Concat forward and reverse results on dim 0
|
||||||
|
Y_h_output_list = b.create<PrimListConstructOp>(
|
||||||
|
b.getType<ListType>(Y_h_Y_c_uni_type),
|
||||||
|
SmallVector<Value>{Y_h_forward, Y_h_reverse});
|
||||||
|
Y_h_result =
|
||||||
|
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_h_output_list, cstZero);
|
||||||
|
|
||||||
|
Y_c_output_list = b.create<PrimListConstructOp>(
|
||||||
|
b.getType<ListType>(Y_h_Y_c_uni_type),
|
||||||
|
SmallVector<Value>{Y_c_forward, Y_c_reverse});
|
||||||
|
Y_c_result =
|
||||||
|
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_c_output_list, cstZero);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (layout == 1) {
|
||||||
|
// Update Y, Y_h, Y_c results to layout 1
|
||||||
|
Y_result = StaticTranspose(b, Y_result, 1, 2);
|
||||||
|
Y_result = StaticTranspose(b, Y_result, 0, 1);
|
||||||
|
Y_h_result = StaticTranspose(b, Y_h_result, 0, 1);
|
||||||
|
Y_c_result = StaticTranspose(b, Y_c_result, 0, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only add outputs specified in onnx output node
|
||||||
|
SmallVector<Value> actualOutputs = {Y_result, Y_h_result, Y_c_result},
|
||||||
|
outputs;
|
||||||
|
ValueTensorType resTy;
|
||||||
|
for (int i = 0; i < binder.getNumResults(); ++i) {
|
||||||
|
if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
|
||||||
|
outputs.push_back(cstNone);
|
||||||
|
} else {
|
||||||
|
outputs.push_back(actualOutputs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(binder.op, outputs);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1072,11 +1273,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
||||||
Value cstNone = b.create<ConstantNoneOp>();
|
Value cstNone = b.create<ConstantNoneOp>();
|
||||||
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
|
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
|
||||||
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
|
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
|
||||||
Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));
|
|
||||||
|
|
||||||
// Binding arguments
|
// Binding arguments
|
||||||
ValueTensorType yTy, Y_hType;
|
ValueTensorType yTy, Y_hType;
|
||||||
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
|
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
|
||||||
binder.tensorResultTypeAtIndex(Y_hType, 1)) {
|
binder.tensorResultTypeAtIndex(Y_hType, 1)) {
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"At least one output must be present");
|
"At least one output must be present");
|
||||||
|
@ -1132,6 +1332,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
||||||
// Validations
|
// Validations
|
||||||
auto XShape = xTy.getSizes();
|
auto XShape = xTy.getSizes();
|
||||||
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
|
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
|
||||||
|
int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
|
||||||
int64_t input_size = XShape[2];
|
int64_t input_size = XShape[2];
|
||||||
|
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
@ -1173,6 +1374,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
||||||
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
|
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
|
||||||
initial_h =
|
initial_h =
|
||||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||||
|
} else {
|
||||||
|
if (layout == 1) {
|
||||||
|
initial_h = StaticTranspose(b, initial_h, 0, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (binder.tensorOperandAtIndex(sequence_lens, 4))
|
if (binder.tensorOperandAtIndex(sequence_lens, 4))
|
||||||
|
@ -1192,10 +1397,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
||||||
// fill in B
|
// fill in B
|
||||||
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
|
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
|
||||||
if (B == nullptr) {
|
if (B == nullptr) {
|
||||||
SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
|
SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
|
||||||
SmallVector<Value> BShapeListContents = {
|
SmallVector<Value> BShapeListContents = {
|
||||||
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
|
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
|
||||||
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
|
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
|
||||||
Value BShapeList = b.create<PrimListConstructOp>(
|
Value BShapeList = b.create<PrimListConstructOp>(
|
||||||
b.getType<ListType>(intType), BShapeListContents);
|
b.getType<ListType>(intType), BShapeListContents);
|
||||||
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
|
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
|
||||||
|
@ -1256,51 +1461,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
||||||
B_slices[4], B_slices[5]);
|
B_slices[4], B_slices[5]);
|
||||||
|
|
||||||
// Process inputs based on layout
|
// Process inputs based on layout
|
||||||
Value X_processed, initial_h_processed;
|
if (layout == 1) {
|
||||||
ValueTensorType yTy_processed, Y_hType_processed;
|
X = StaticTranspose(b, X, 0, 1);
|
||||||
|
|
||||||
if (layout == 0) {
|
|
||||||
X_processed = X;
|
|
||||||
initial_h_processed = initial_h_forward;
|
|
||||||
yTy_processed = yTy;
|
|
||||||
Y_hType_processed = Y_hType;
|
|
||||||
} else {
|
|
||||||
X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
|
|
||||||
initial_h_processed = b.create<AtenTransposeIntOp>(
|
|
||||||
initial_h.getType(), initial_h_forward, cstZero, cstOne);
|
|
||||||
|
|
||||||
auto yTySizes = yTy.getSizes();
|
|
||||||
auto Y_hTypeSizes = Y_hType.getSizes();
|
|
||||||
|
|
||||||
yTy_processed = b.getType<ValueTensorType>(
|
|
||||||
llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
|
|
||||||
yTySizes[3]},
|
|
||||||
yTy.getDtype());
|
|
||||||
|
|
||||||
Y_hType_processed = b.getType<ValueTensorType>(
|
|
||||||
llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
|
|
||||||
Y_hTypeSizes[2]},
|
|
||||||
Y_hType.getDtype());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Weights and biases ready. Calling GRU layer to insert the actual ops.
|
// Weights and biases ready. Calling GRU layer to insert the actual ops.
|
||||||
GruLayerOutput gruLayerOutput =
|
GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
|
||||||
gru_layer(b, X_processed, initial_h_processed, weights, activations,
|
activations, linear_before_reset);
|
||||||
linear_before_reset);
|
|
||||||
|
|
||||||
// Process outputs based on layout
|
// Process outputs based on layout
|
||||||
Value Y_final, Y_h_final;
|
Value Y_final;
|
||||||
|
if (binder.tensorResultTypeAtIndex(yTy, 0)) {
|
||||||
|
Y_final = cstNone;
|
||||||
|
} else {
|
||||||
if (layout == 0) {
|
if (layout == 0) {
|
||||||
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
|
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
|
||||||
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
|
|
||||||
} else {
|
} else {
|
||||||
auto Y_transposed = b.create<AtenTransposeIntOp>(
|
Type yTy_original = b.getType<ValueTensorType>(
|
||||||
gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
|
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
|
||||||
Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
|
yTy.getDtype());
|
||||||
|
Y_final =
|
||||||
|
b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
|
||||||
|
Y_final = StaticTranspose(b, Y_final, 1, 2);
|
||||||
|
Y_final = StaticTranspose(b, Y_final, 0, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto Y_h_transposed = b.create<AtenTransposeIntOp>(
|
Value Y_h_final;
|
||||||
gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
|
if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
|
||||||
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
|
Y_h_final = cstNone;
|
||||||
|
} else {
|
||||||
|
if (layout == 0) {
|
||||||
|
Y_h_final =
|
||||||
|
b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
|
||||||
|
} else {
|
||||||
|
Type y_hTy_original = b.getType<ValueTensorType>(
|
||||||
|
llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
|
||||||
|
Y_hType.getDtype());
|
||||||
|
Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
|
||||||
|
cstZero);
|
||||||
|
Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
|
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
|
||||||
|
|
|
@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
|
||||||
template <typename OpTy, typename OpAdaptor>
|
template <typename OpTy, typename OpAdaptor>
|
||||||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
|
int64_t &dim,
|
||||||
SmallVector<Value> &resultShape,
|
SmallVector<Value> &resultShape,
|
||||||
SmallVector<Value> &offsets,
|
SmallVector<Value> &offsets,
|
||||||
SmallVector<Value> &strides) {
|
SmallVector<Value> &strides) {
|
||||||
|
@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
|
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
|
||||||
|
|
||||||
int64_t dim;
|
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return op->emitError("unimplemented: dim is not constant");
|
return op->emitError("unimplemented: dim is not constant");
|
||||||
|
|
||||||
|
@ -1658,10 +1658,17 @@ public:
|
||||||
if (!isValidDim(dim, inputRank))
|
if (!isValidDim(dim, inputRank))
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
|
||||||
// TODO: Handle the case where the dim(th) dimension is dynamic.
|
// assert dynamic squeeze dim size == 1
|
||||||
if (inputType.isDynamicDim(dim)) {
|
if (inputType.isDynamicDim(dim)) {
|
||||||
return rewriter.notifyMatchFailure(
|
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
|
||||||
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
|
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
|
||||||
|
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
|
||||||
|
Value cmp = rewriter.create<arith::CmpIOp>(
|
||||||
|
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
|
||||||
|
rewriter.create<cf::AssertOp>(
|
||||||
|
op.getLoc(), cmp,
|
||||||
|
rewriter.getStringAttr(
|
||||||
|
"Expected dynamic squeeze dim size to be statically 1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
@ -1671,7 +1678,7 @@ public:
|
||||||
|
|
||||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||||
// `aten.squeeze` will behave as an identity operation.
|
// `aten.squeeze` will behave as an identity operation.
|
||||||
if (inputType.getDimSize(dim) != 1) {
|
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1857,14 +1864,46 @@ public:
|
||||||
RankedTensorType resultType = cast<RankedTensorType>(
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
typeConverter->convertType(op->getResult(0).getType()));
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
SmallVector<Value> resultShape, offsets, strides;
|
||||||
SmallVector<Value> offsets;
|
int64_t dim;
|
||||||
SmallVector<Value> strides;
|
|
||||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
|
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
|
||||||
AtenSliceTensorOpAdaptor>(
|
AtenSliceTensorOpAdaptor>(
|
||||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If stride is negative, then flip the input tensor corresponding to that
|
||||||
|
// dim, update the stride for flipped tensor by multiplying it by -1, and
|
||||||
|
// update the offset as follows:
|
||||||
|
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
// Input = [0, 1, 2, 3, 4, 5]
|
||||||
|
// stride = [-2], result_shape = [2], offset = [3]
|
||||||
|
// Result = [3, 1]
|
||||||
|
// After flipping:
|
||||||
|
// Input = [5, 4, 3, 2, 1, 0]
|
||||||
|
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
|
||||||
|
// Result = [3, 1]
|
||||||
|
|
||||||
|
Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
|
||||||
|
SmallVector<int64_t>{dim});
|
||||||
|
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
|
||||||
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::slt, strides[dim], zero);
|
||||||
|
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
|
||||||
|
Value resShapeMulStride =
|
||||||
|
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
|
||||||
|
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
|
||||||
|
Value flippedOffset =
|
||||||
|
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
|
||||||
|
offsets[dim] = rewriter.create<arith::SelectOp>(
|
||||||
|
loc, isNegativeStride, flippedOffset, offsets[dim]);
|
||||||
|
|
||||||
|
input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
|
||||||
|
flippedInput, input);
|
||||||
|
|
||||||
SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
|
SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
|
||||||
auto sliceType = RankedTensorType::get(
|
auto sliceType = RankedTensorType::get(
|
||||||
dynShape, resultType.getElementType(), resultType.getEncoding());
|
dynShape, resultType.getElementType(), resultType.getEncoding());
|
||||||
|
@ -2095,12 +2134,11 @@ public:
|
||||||
RankedTensorType resultType = cast<RankedTensorType>(
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
typeConverter->convertType(op->getResult(0).getType()));
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
SmallVector<Value> resultShape, offsets, strides;
|
||||||
SmallVector<Value> offsets;
|
int64_t dim;
|
||||||
SmallVector<Value> strides;
|
|
||||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||||
AtenSliceScatterOpAdaptor>(
|
AtenSliceScatterOpAdaptor>(
|
||||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2573,6 +2611,167 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenUnfoldOp : public OpConversionPattern<AtenUnfoldOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
|
int64_t dimension;
|
||||||
|
if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"only support constant int dimension");
|
||||||
|
}
|
||||||
|
int64_t size;
|
||||||
|
if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only support constant int size");
|
||||||
|
}
|
||||||
|
int64_t step;
|
||||||
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "only support constant int step");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (step <= 0) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "step must be greater than zero.");
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t selfRank = selfType.getRank();
|
||||||
|
|
||||||
|
// Zero-Rank case
|
||||||
|
if (selfRank == 0) {
|
||||||
|
// Empty tensor
|
||||||
|
if (size == 0) {
|
||||||
|
RankedTensorType resultType =
|
||||||
|
RankedTensorType::get({0}, selfType.getElementType());
|
||||||
|
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
|
loc, resultType.getShape(), resultType.getElementType());
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, emptyTensor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value unsqueezedSelf = rewriter.create<tensor::ExpandShapeOp>(
|
||||||
|
loc, RankedTensorType::get({1}, selfType.getElementType()), self,
|
||||||
|
ArrayRef<ReassociationIndices>{});
|
||||||
|
rewriter.replaceOp(op, unsqueezedSelf);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape = selfType.getShape();
|
||||||
|
|
||||||
|
if (dimension < 0) {
|
||||||
|
dimension = toPositiveDim(dimension, selfRank);
|
||||||
|
}
|
||||||
|
if (!isValidDim(dimension, selfRank)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dimension out of range");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value dimSize = rewriter.create<tensor::DimOp>(loc, self, dimension);
|
||||||
|
|
||||||
|
Value sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
|
||||||
|
Value sizeCheck = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::ule, sizeValue, dimSize);
|
||||||
|
rewriter.create<cf::AssertOp>(
|
||||||
|
loc, sizeCheck,
|
||||||
|
rewriter.getStringAttr("size must be <= target dimension"));
|
||||||
|
|
||||||
|
/* Calculate output shape of unfold op:
|
||||||
|
* https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
|
||||||
|
* outputShape[dimension] is set to numBlocks, with size appended as an
|
||||||
|
* additional dimension
|
||||||
|
*/
|
||||||
|
SmallVector<OpFoldResult> outputShape;
|
||||||
|
for (int64_t i = 0; i < selfRank; i++) {
|
||||||
|
if (i == dimension) {
|
||||||
|
outputShape.push_back(getDynamicOrStaticNumBlocks(
|
||||||
|
rewriter, loc, shape[dimension], dimSize, size, step));
|
||||||
|
} else if (shape[i] == ShapedType::kDynamic) {
|
||||||
|
outputShape.push_back(
|
||||||
|
OpFoldResult(rewriter.create<tensor::DimOp>(loc, self, i)));
|
||||||
|
} else {
|
||||||
|
outputShape.push_back(rewriter.getIndexAttr(shape[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputShape.push_back(rewriter.getIndexAttr(size));
|
||||||
|
|
||||||
|
// Empty tensor to insert values into
|
||||||
|
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
|
loc, outputShape, selfType.getElementType());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use reindexing to map output indices to input indices
|
||||||
|
* i.e. In output of rank 3 case:
|
||||||
|
* (i, j, k) => (i', j') where i' = i * step + k and j' = j
|
||||||
|
* if dimension == 0
|
||||||
|
* (i, j, k) => (i', j') where i' = i and j' = j * step + k
|
||||||
|
* if dimension == 1
|
||||||
|
*/
|
||||||
|
MLIRContext *context = rewriter.getContext();
|
||||||
|
SmallVector<AffineExpr> outputExprs;
|
||||||
|
for (int dim = 0; dim < selfRank; ++dim) {
|
||||||
|
if (dim == dimension) {
|
||||||
|
auto idxLast = getAffineDimExpr(selfRank, context);
|
||||||
|
auto idxDimension = getAffineDimExpr(dimension, context);
|
||||||
|
|
||||||
|
AffineExpr dimIdx =
|
||||||
|
idxLast + idxDimension * rewriter.getAffineConstantExpr(step);
|
||||||
|
outputExprs.push_back(dimIdx);
|
||||||
|
} else {
|
||||||
|
outputExprs.push_back(getAffineDimExpr(dim, context));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t outputRank = selfRank + 1;
|
||||||
|
auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context);
|
||||||
|
auto outputAffineMap =
|
||||||
|
AffineMap::getMultiDimIdentityMap(outputRank, context);
|
||||||
|
|
||||||
|
SmallVector<utils::IteratorType> iteratorTypes(
|
||||||
|
outputRank, utils::IteratorType::parallel);
|
||||||
|
|
||||||
|
Value result =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, outputTensor.getType(), self, outputTensor,
|
||||||
|
ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes,
|
||||||
|
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
|
||||||
|
b.create<linalg::YieldOp>(nestedLoc, args[0]);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc,
|
||||||
|
int64_t shapeDim, Value dimSize,
|
||||||
|
int64_t size, int64_t step) const {
|
||||||
|
/**
|
||||||
|
* numBlocks = (shape[dimension] - size) // step + 1
|
||||||
|
*/
|
||||||
|
if (shapeDim == ShapedType::kDynamic) {
|
||||||
|
Value numBlocksSubOp = rewriter.create<arith::SubIOp>(
|
||||||
|
loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, size));
|
||||||
|
Value numBlocksDivOp = rewriter.create<arith::DivUIOp>(
|
||||||
|
loc, numBlocksSubOp,
|
||||||
|
rewriter.create<arith::ConstantIndexOp>(loc, step));
|
||||||
|
Value numBlocks = rewriter.create<arith::AddIOp>(
|
||||||
|
loc, rewriter.create<arith::ConstantIndexOp>(loc, 1), numBlocksDivOp);
|
||||||
|
return OpFoldResult(numBlocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t staticNumBlocks = (shapeDim - size) / step + 1;
|
||||||
|
return rewriter.getIndexAttr(staticNumBlocks); // Use static value
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
|
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -2641,7 +2840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||||
/*benefit=*/200);
|
/*benefit=*/200);
|
||||||
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
|
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
|
||||||
/*benefit=*/100);
|
/*benefit=*/100);
|
||||||
|
target.addIllegalOp<AtenUnfoldOp>();
|
||||||
|
patterns.add<ConvertAtenUnfoldOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSqueezeOp>();
|
target.addIllegalOp<AtenSqueezeOp>();
|
||||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSqueezeDimOp>();
|
target.addIllegalOp<AtenSqueezeDimOp>();
|
||||||
|
|
|
@ -301,14 +301,9 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
MLIRContext *context = op.getContext();
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfRank =
|
auto selfRank =
|
||||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
Type elementType =
|
|
||||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
|
|
||||||
Value c1 =
|
|
||||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
|
||||||
|
|
||||||
SmallVector<int64_t> axis;
|
SmallVector<int64_t> axis;
|
||||||
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
||||||
|
@ -321,40 +316,8 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
|
||||||
// dims won't be used.
|
|
||||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
|
|
||||||
for (auto flipDim : axis)
|
|
||||||
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
|
|
||||||
|
|
||||||
Value initTensor = createZeroInitTensor(
|
|
||||||
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);
|
|
||||||
|
|
||||||
SmallVector<utils::IteratorType> iteratorTypes(
|
|
||||||
selfRank, utils::IteratorType::parallel);
|
|
||||||
SmallVector<AffineMap> indexingMaps(
|
|
||||||
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
|
|
||||||
Value flipped =
|
|
||||||
rewriter
|
|
||||||
.create<linalg::GenericOp>(
|
|
||||||
loc, self.getType(), self, initTensor, indexingMaps,
|
|
||||||
iteratorTypes,
|
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
||||||
SmallVector<Value> indices;
|
|
||||||
for (auto i = 0; i < selfRank; i++)
|
|
||||||
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
|
||||||
for (auto flipDim : axis) {
|
|
||||||
indices[flipDim] = b.create<arith::SubIOp>(
|
|
||||||
loc, dims[flipDim], indices[flipDim]);
|
|
||||||
}
|
|
||||||
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
|
|
||||||
.getResult();
|
|
||||||
b.create<linalg::YieldOp>(loc, res);
|
|
||||||
})
|
|
||||||
.getResult(0);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1300,10 +1263,6 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (numSpatialDims != 2)
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: only 2D grouped convolution supported");
|
|
||||||
|
|
||||||
// Special depthwise case: Cin = Cout = groups.
|
// Special depthwise case: Cin = Cout = groups.
|
||||||
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
|
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
|
||||||
// of groups) to be depthwise in their documentation, but the linalg ops
|
// of groups) to be depthwise in their documentation, but the linalg ops
|
||||||
|
@ -1315,21 +1274,45 @@ public:
|
||||||
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
|
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
|
||||||
weightShape[1] == 1) {
|
weightShape[1] == 1) {
|
||||||
// Collapse weight shape (C/G == 1)
|
// Collapse weight shape (C/G == 1)
|
||||||
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
|
||||||
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
|
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
|
||||||
weightShape[2], weightShape[3]};
|
for (unsigned i = 0; i < numSpatialDims; i++) {
|
||||||
|
collapsedDims.push_back({i + 2});
|
||||||
|
collapsedShape.push_back(weightShape[i + 2]);
|
||||||
|
}
|
||||||
Type collapsedType = RankedTensorType::get(
|
Type collapsedType = RankedTensorType::get(
|
||||||
makeShapeLLVMCompatible(collapsedShape), weightDTy);
|
makeShapeLLVMCompatible(collapsedShape), weightDTy);
|
||||||
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
||||||
loc, collapsedType, weight, collapsedDims);
|
loc, collapsedType, weight, collapsedDims);
|
||||||
if (!inputZp) {
|
if (!inputZp) {
|
||||||
|
switch (numSpatialDims) {
|
||||||
|
case 1:
|
||||||
|
conv = rewriter
|
||||||
|
.create<linalg::DepthwiseConv1DNcwCwOp>(
|
||||||
|
loc, outputTensor.getType(),
|
||||||
|
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||||
|
stridesAttr, dilationAttr)
|
||||||
|
.getResult(0);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
conv = rewriter
|
conv = rewriter
|
||||||
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
||||||
loc, outputTensor.getType(),
|
loc, outputTensor.getType(),
|
||||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||||
stridesAttr, dilationAttr)
|
stridesAttr, dilationAttr)
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only 1D and 2D depthwise convolution "
|
||||||
|
"supported for special case of group convolution");
|
||||||
|
};
|
||||||
} else {
|
} else {
|
||||||
|
if (numSpatialDims != 2)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only 2D depthwise quantized convolution "
|
||||||
|
"supported for special case of group convolution");
|
||||||
|
|
||||||
// currently, the only named depthwise qconv op is nhwc_hwc
|
// currently, the only named depthwise qconv op is nhwc_hwc
|
||||||
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
|
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
|
||||||
// linalg conv result nhwc -> nchw
|
// linalg conv result nhwc -> nchw
|
||||||
|
@ -1376,6 +1359,10 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (numSpatialDims != 2)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only 2D grouped convolution supported");
|
||||||
|
|
||||||
// Grouped case, use the grouped conv linalg op
|
// Grouped case, use the grouped conv linalg op
|
||||||
auto expandGroups = [&](Value tensor, size_t dim) {
|
auto expandGroups = [&](Value tensor, size_t dim) {
|
||||||
auto inType = cast<RankedTensorType>(tensor.getType());
|
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||||
|
|
|
@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
|
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
|
||||||
return createEqual(b, loc, floatDtype, self, zero);
|
return createEqual(b, loc, floatDtype, self, zero);
|
||||||
}
|
}
|
||||||
|
if (auto complex = dyn_cast<AtenComplexOp>(op)) {
|
||||||
|
auto ctype = cast<ComplexType>(
|
||||||
|
cast<RankedTensorType>(converter->convertType(complex.getType()))
|
||||||
|
.getElementType());
|
||||||
|
Type stype = ctype.getElementType();
|
||||||
|
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype);
|
||||||
|
return b.create<complex::CreateOp>(loc, ctype, lhs, rhs);
|
||||||
|
}
|
||||||
if (isa<AtenAbsOp>(op)) {
|
if (isa<AtenAbsOp>(op)) {
|
||||||
if (isa<IntegerType>(payloadArgs[0].getType()))
|
if (isa<IntegerType>(payloadArgs[0].getType()))
|
||||||
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
|
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
|
||||||
|
@ -1590,22 +1600,22 @@ public:
|
||||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||||
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||||
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
|
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
|
||||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||||
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
|
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||||
Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
|
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
|
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||||
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
|
||||||
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
|
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
|
||||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
|
||||||
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
|
||||||
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
|
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
|
||||||
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
|
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
|
||||||
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||||
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
|
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
|
@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
|
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
|
||||||
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
|
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
|
||||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
|
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
|
||||||
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp,
|
||||||
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||||
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||||
|
|
|
@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flips an input tensor based on the values of axis list.
|
||||||
|
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
|
||||||
|
Value input, SmallVector<int64_t> axis) {
|
||||||
|
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||||
|
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
|
||||||
|
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||||
|
|
||||||
|
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
||||||
|
// dims won't be used.
|
||||||
|
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
|
||||||
|
for (auto flipDim : axis)
|
||||||
|
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
|
||||||
|
|
||||||
|
Value initTensor = createZeroInitTensor(
|
||||||
|
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);
|
||||||
|
|
||||||
|
SmallVector<utils::IteratorType> iteratorTypes(selfRank,
|
||||||
|
utils::IteratorType::parallel);
|
||||||
|
SmallVector<AffineMap> indexingMaps(
|
||||||
|
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
|
||||||
|
Value flipped =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, input.getType(), input, initTensor, indexingMaps,
|
||||||
|
iteratorTypes,
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
SmallVector<Value> indices;
|
||||||
|
for (auto i = 0; i < selfRank; i++)
|
||||||
|
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||||
|
for (auto flipDim : axis) {
|
||||||
|
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
|
||||||
|
indices[flipDim]);
|
||||||
|
}
|
||||||
|
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
|
||||||
|
.getResult();
|
||||||
|
b.create<linalg::YieldOp>(loc, res);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
return flipped;
|
||||||
|
}
|
||||||
|
|
|
@ -325,7 +325,8 @@ public:
|
||||||
lhsContractingDim, rhsContractingDim);
|
lhsContractingDim, rhsContractingDim);
|
||||||
output = rewriter
|
output = rewriter
|
||||||
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||||
dotDimensionNumbers, nullptr)
|
dotDimensionNumbers, nullptr,
|
||||||
|
nullptr)
|
||||||
.getResult();
|
.getResult();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -494,7 +495,7 @@ public:
|
||||||
/*lhsContractingDimensions=*/{lhsContractingDim},
|
/*lhsContractingDimensions=*/{lhsContractingDim},
|
||||||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||||
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
|
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
|
||||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr);
|
||||||
|
|
||||||
Value matmulPlusBias = matmulOutput;
|
Value matmulPlusBias = matmulOutput;
|
||||||
if (!isa<Torch::NoneType>(biasTy)) {
|
if (!isa<Torch::NoneType>(biasTy)) {
|
||||||
|
|
|
@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
|
|
||||||
// Max pooling
|
// Max pooling
|
||||||
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
|
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
|
||||||
AtenMaxPool2dWithIndicesOp>(op)) {
|
AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
||||||
if (isa<mlir::FloatType>(elementTy)) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
|
@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AtenMaxPool1dWithIndicesOp
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
|
||||||
|
AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
auto inputShape = inputTy.getShape();
|
||||||
|
auto inputRank = inputTy.getRank();
|
||||||
|
|
||||||
|
auto outValTy =
|
||||||
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
||||||
|
auto outIdxTy =
|
||||||
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||||
|
|
||||||
|
if (inputRank <= 1) {
|
||||||
|
return op.emitError(
|
||||||
|
"max_pooling1d only supports inputs with rank higher than 1");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
|
||||||
|
bool ceilMode = false;
|
||||||
|
|
||||||
|
if (!(matchPattern(op.getKernelSize(),
|
||||||
|
m_TorchListOfConstantInts(kernelSize)))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "non-const int kernel size unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int padding unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int dilation unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const bool ceil_mode unsupported!");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||||
|
|
||||||
|
std::copy(stride.begin(), stride.end(),
|
||||||
|
stablehloStride.begin() + inputRank - 1);
|
||||||
|
std::copy(dilation.begin(), dilation.end(),
|
||||||
|
stablehloDilation.begin() + inputRank - 1);
|
||||||
|
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||||
|
stablehloKernelSize.begin() + inputRank - 1);
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||||
|
|
||||||
|
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||||
|
|
||||||
|
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||||
|
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||||
|
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||||
|
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||||
|
RankedTensorType::get(
|
||||||
|
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||||
|
rewriter.getI64Type()),
|
||||||
|
stablehloPadding);
|
||||||
|
DenseI64ArrayAttr baseDilations;
|
||||||
|
|
||||||
|
auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
|
||||||
|
if (failed(inputShapeInfo)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "failed to get dimension sizes of the input");
|
||||||
|
}
|
||||||
|
auto inputShapeVec = *inputShapeInfo;
|
||||||
|
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||||
|
op->getLoc(), inputShapeVec);
|
||||||
|
|
||||||
|
// no need to reshape here for max_pool_1d. Need to make sure the iota
|
||||||
|
// dimension. dim=inputRank-2 or dim=inputRank-1?
|
||||||
|
auto indexTensor =
|
||||||
|
rewriter
|
||||||
|
.create<stablehlo::DynamicIotaOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(inputShape, rewriter.getI64Type()),
|
||||||
|
inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
|
||||||
|
.getResult();
|
||||||
|
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||||
|
|
||||||
|
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||||
|
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
|
||||||
|
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
|
||||||
|
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
|
||||||
|
|
||||||
|
// add block.
|
||||||
|
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
||||||
|
auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||||
|
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
|
||||||
|
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
|
||||||
|
block.addArgument(blockValArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockValArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||||
|
auto *firstValArg = block.args_begin();
|
||||||
|
auto *firstIdxArg = std::next(firstValArg);
|
||||||
|
auto *secondValArg = std::next(firstIdxArg);
|
||||||
|
auto *secondIdxArg = std::next(secondValArg);
|
||||||
|
|
||||||
|
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||||
|
if (isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||||
|
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||||
|
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||||
|
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
|
||||||
|
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||||
|
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||||
|
}
|
||||||
|
|
||||||
|
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||||
|
stablehlo::ComparisonDirectionAttr::get(
|
||||||
|
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
||||||
|
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||||
|
stablehlo::ComparisonDirectionAttr::get(
|
||||||
|
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
|
||||||
|
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
|
||||||
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||||
|
compareGeDirectionAttr, compareTypeAttr);
|
||||||
|
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
||||||
|
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||||
|
|
||||||
|
// Get smaller index if compared values are equal.
|
||||||
|
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
||||||
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||||
|
compareEqDirectionAttr, compareTypeAttr);
|
||||||
|
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
||||||
|
*secondIdxArg);
|
||||||
|
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
||||||
|
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||||
|
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
||||||
|
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||||
|
|
||||||
|
rewriter.create<stablehlo::ReturnOp>(
|
||||||
|
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// AtenMaxPool2dWithIndicesOp
|
// AtenMaxPool2dWithIndicesOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
|
@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||||
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
|
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||||
|
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
|
||||||
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
|
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
|
||||||
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
|
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
|
||||||
#undef INSERT_ATEN_POOLING_PATTERN
|
#undef INSERT_ATEN_POOLING_PATTERN
|
||||||
|
|
|
@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenAllOp>(op)) {
|
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
||||||
auto constAttr =
|
auto constAttr =
|
||||||
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
|
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
|
||||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||||
|
@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
|
||||||
AtenLinalgVectorNormOp>(op)) {
|
AtenLinalgVectorNormOp>(op)) {
|
||||||
result = rewriter.create<stablehlo::AddOp>(
|
result = rewriter.create<stablehlo::AddOp>(
|
||||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||||
} else if (isa<AtenAllOp>(op)) {
|
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
||||||
result = rewriter.create<stablehlo::AndOp>(
|
result = rewriter.create<stablehlo::AndOp>(
|
||||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||||
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
|
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
|
||||||
|
@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||||
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
|
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
|
||||||
options)
|
options)
|
||||||
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
|
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
|
||||||
|
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
|
||||||
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
|
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
|
||||||
|
|
||||||
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
|
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -161,12 +161,70 @@ public:
|
||||||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
|
||||||
|
unsigned getBitWidth(Type type) const {
|
||||||
|
if (auto complexTy = dyn_cast<ComplexType>(type))
|
||||||
|
return 2 * getBitWidth(complexTy.getElementType());
|
||||||
|
return type.getIntOrFloatBitWidth();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
if (!rankType)
|
if (!rankType)
|
||||||
return op.emitError("Only ranked tensor types are currently supported");
|
return op.emitError("Only ranked tensor types are currently supported.");
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
// support AtenViewDtypeOp
|
||||||
|
if (isa<AtenViewDtypeOp>(op)) {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
|
||||||
|
|
||||||
|
// infer the result shape
|
||||||
|
auto operandElt = rankType.getElementType();
|
||||||
|
auto targetElt = baseResultTy.getDtype();
|
||||||
|
auto operandEltBitWidth = getBitWidth(operandElt);
|
||||||
|
auto targetEltBitWidth = getBitWidth(targetElt);
|
||||||
|
auto operandSizes = rankType.getShape();
|
||||||
|
SmallVector<int64_t> castShape(operandSizes);
|
||||||
|
if (operandEltBitWidth > targetEltBitWidth) {
|
||||||
|
int64_t last_size = operandEltBitWidth / targetEltBitWidth;
|
||||||
|
castShape.push_back(last_size);
|
||||||
|
} else if (operandEltBitWidth < targetEltBitWidth) {
|
||||||
|
int64_t last_size = targetEltBitWidth / operandEltBitWidth;
|
||||||
|
if (!ShapedType::isDynamic(castShape.back()) and
|
||||||
|
last_size != castShape.back()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "The last dim size is not equal to targetEltBitWidth / "
|
||||||
|
"operandEltBitWidth.");
|
||||||
|
} else {
|
||||||
|
castShape.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
baseResultTy);
|
||||||
|
if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Currently only support static output shape.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto castType =
|
||||||
|
baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
|
||||||
|
auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
|
||||||
|
loc,
|
||||||
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
castType),
|
||||||
|
self);
|
||||||
|
|
||||||
|
auto reshape =
|
||||||
|
rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, reshape);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// collect Value of dims
|
// collect Value of dims
|
||||||
SmallVector<Value, 4> dimSizes;
|
SmallVector<Value, 4> dimSizes;
|
||||||
|
@ -174,7 +232,6 @@ public:
|
||||||
return op.emitError("Dims size must be a list of Scalar");
|
return op.emitError("Dims size must be a list of Scalar");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
|
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
|
||||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||||
op,
|
op,
|
||||||
|
@ -236,6 +293,13 @@ public:
|
||||||
SmallVector<Value, 4> &dimSizes) const;
|
SmallVector<Value, 4> &dimSizes) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
|
||||||
|
AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||||
|
SmallVector<Value, 4> &dimSizes) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
|
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
|
||||||
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||||
|
@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
|
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
|
||||||
|
INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
|
||||||
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
||||||
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
||||||
#undef INSERT_VIEW_OP_PATTERN
|
#undef INSERT_VIEW_OP_PATTERN
|
||||||
|
|
|
@ -1497,6 +1497,79 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto resultType = cast<RankedTensorType>(
|
||||||
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
|
Type elementType = resultType.getElementType();
|
||||||
|
Type inputElementType =
|
||||||
|
cast<RankedTensorType>(input.getType()).getElementType();
|
||||||
|
|
||||||
|
// Converting the input element type to the result's element type.
|
||||||
|
// The only possible mismatch would be when the input element type is an
|
||||||
|
// integer but not `si64`. Therefore, we directly convert the input to
|
||||||
|
// `si64`. Rest all cases are handled in the dtype definition for this op.
|
||||||
|
if (elementType != inputElementType) {
|
||||||
|
Value torchInput = convertTensorToDtype(
|
||||||
|
rewriter, loc, op.getSelf(),
|
||||||
|
rewriter.getIntegerType(64, IntegerType::Signed));
|
||||||
|
input = typeConverter->materializeTargetConversion(
|
||||||
|
rewriter, loc, typeConverter->convertType(torchInput.getType()),
|
||||||
|
torchInput);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t inputRank = resultType.getRank();
|
||||||
|
Value dtype = op.getDtype();
|
||||||
|
if (!isa<Torch::NoneType>(dtype.getType()))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unsupported: dtype argument not supported");
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only constant dim value is supported");
|
||||||
|
dim = toPositiveDim(dim, inputRank);
|
||||||
|
if (!isValidDim(dim, inputRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "invalid dim");
|
||||||
|
|
||||||
|
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
|
||||||
|
Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
|
||||||
|
output = rewriter.create<tensor::CastOp>(loc, resultType, output);
|
||||||
|
|
||||||
|
SmallVector<Value> accSizes(sizes);
|
||||||
|
accSizes.erase(accSizes.begin() + dim);
|
||||||
|
SmallVector<int64_t> accStatic(
|
||||||
|
makeShapeTorchCompatible(resultType.getShape()));
|
||||||
|
accStatic.erase(accStatic.begin() + dim);
|
||||||
|
Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
|
||||||
|
Type accType =
|
||||||
|
RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
|
||||||
|
acc = rewriter.create<tensor::CastOp>(loc, accType, acc);
|
||||||
|
|
||||||
|
Value result = createTMTensorScanOp(
|
||||||
|
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
|
||||||
|
[](OpBuilder &b, Location loc, Value input, Value acc) {
|
||||||
|
Value prod =
|
||||||
|
(isa<mlir::FloatType>(input.getType())
|
||||||
|
? b.create<arith::MulFOp>(loc, input, acc)->getResult(0)
|
||||||
|
: b.create<arith::MulIOp>(loc, input, acc)->getResult(0));
|
||||||
|
b.create<TMTensor::YieldOp>(loc, prod);
|
||||||
|
});
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
|
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -2240,6 +2313,8 @@ public:
|
||||||
patterns.add<ConvertAtenSortOp>(typeConverter, context);
|
patterns.add<ConvertAtenSortOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenCumsumOp>();
|
target.addIllegalOp<AtenCumsumOp>();
|
||||||
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
|
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenCumprodOp>();
|
||||||
|
patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenScaledDotProductAttentionOp>();
|
target.addIllegalOp<AtenScaledDotProductAttentionOp>();
|
||||||
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
||||||
context);
|
context);
|
||||||
|
|
|
@ -153,9 +153,15 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Unable to extract the scalar constant");
|
"Unable to extract the scalar constant");
|
||||||
|
|
||||||
|
int64_t numElem = 1;
|
||||||
|
for (int64_t dim : dshape)
|
||||||
|
numElem *= dim;
|
||||||
|
|
||||||
if (isa<mlir::FloatType>(dtype)) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
|
tosaTensor =
|
||||||
(isFloat ? doubleValue : intValue),
|
tosa::getConstTensor<float>(
|
||||||
|
rewriter, op,
|
||||||
|
SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
|
||||||
dshape, dtype)
|
dshape, dtype)
|
||||||
.value();
|
.value();
|
||||||
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
|
@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
bool d = isFloat ? static_cast<bool>(doubleValue)
|
bool d = isFloat ? static_cast<bool>(doubleValue)
|
||||||
: static_cast<bool>(intValue);
|
: static_cast<bool>(intValue);
|
||||||
tosaTensor =
|
tosaTensor = tosa::getConstTensor<bool>(
|
||||||
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
|
rewriter, op, SmallVector<bool>(numElem, d), dshape)
|
||||||
|
.value();
|
||||||
} else if (w == 32) {
|
} else if (w == 32) {
|
||||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||||
: static_cast<int32_t>(intValue);
|
: static_cast<int32_t>(intValue);
|
||||||
tosaTensor =
|
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
|
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
|
||||||
|
.value();
|
||||||
} else if (w == 64) {
|
} else if (w == 64) {
|
||||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
"of destination type");
|
"of destination type");
|
||||||
}
|
}
|
||||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||||
tosaTensor =
|
tosaTensor = tosa::getConstTensor<int64_t>(
|
||||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
|
rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
|
||||||
|
.value();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
||||||
|
@ -891,8 +900,6 @@ public:
|
||||||
if (!result)
|
if (!result)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// TBD - support dtype casting.
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, {result.value()});
|
rewriter.replaceOp(op, {result.value()});
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -2842,8 +2849,12 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
|
SmallVector<int32_t> dimListInt32;
|
||||||
rewriter, op.getOperation(), dimListInt, {selfRank});
|
for (auto v : dimListInt)
|
||||||
|
dimListInt32.push_back(v);
|
||||||
|
|
||||||
|
auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
|
||||||
|
rewriter, op.getOperation(), dimListInt32, {selfRank});
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
|
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
|
||||||
|
@ -3819,6 +3830,124 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
||||||
|
AtenIndexSelectOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// Not a tensor type.
|
||||||
|
auto input = adaptor.getSelf();
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
|
if (!inputType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType inputs are currently supported");
|
||||||
|
|
||||||
|
auto index = adaptor.getIndex();
|
||||||
|
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||||
|
|
||||||
|
if (!indexType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType indices are currently supported");
|
||||||
|
|
||||||
|
auto inputShape = inputType.getShape();
|
||||||
|
int inputRank = inputType.getRank();
|
||||||
|
|
||||||
|
if (indexType.getRank() == 0)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Rank 0 index tensor is currently not supported");
|
||||||
|
|
||||||
|
// Dynamic shape check
|
||||||
|
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "AtenIndexSelectOp: support for dynamic input "
|
||||||
|
"shape not implemented");
|
||||||
|
|
||||||
|
// index i64 to i32 for tosa compatible
|
||||||
|
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||||
|
index = rewriter.create<tosa::CastOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(indexType.getShape(),
|
||||||
|
rewriter.getIntegerType(32)),
|
||||||
|
index);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get positive dim
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Value `dim` should be a torch constant int");
|
||||||
|
dim = toPositiveDim(dim, inputRank);
|
||||||
|
if (!isValidDim(dim, inputRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Value `dim` is invalid");
|
||||||
|
|
||||||
|
// Get the output type
|
||||||
|
auto outType = getTypeConverter()->convertType(op.getType());
|
||||||
|
|
||||||
|
// Reshape and expand the index tensor to have same rank and same dimensions
|
||||||
|
// (except for the targeted dim) as the input
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
// Input shape = (4, 5, 6)
|
||||||
|
// Index vector shape = (2)
|
||||||
|
// Targeted dim = 1
|
||||||
|
// Reshaped and expanded index vector shape = (4, 2, 6)
|
||||||
|
//
|
||||||
|
// By reshaping and expanding the index vector, we can supply it into the
|
||||||
|
// gather op to mimic the functionality of aten.index_select
|
||||||
|
SmallVector<int64_t> indicesInputRankShape;
|
||||||
|
for (int64_t i = 0; i < inputRank; i++) {
|
||||||
|
if (i == dim) {
|
||||||
|
indicesInputRankShape.push_back(indexType.getShape()[0]);
|
||||||
|
} else {
|
||||||
|
indicesInputRankShape.push_back(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto indicesInputRankType =
|
||||||
|
RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape),
|
||||||
|
rewriter.getIntegerType(32));
|
||||||
|
|
||||||
|
auto reshapedIndices = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), indicesInputRankType, index,
|
||||||
|
rewriter.getDenseI64ArrayAttr(indicesInputRankShape));
|
||||||
|
|
||||||
|
SmallVector<int64_t> tileShape(indicesInputRankShape);
|
||||||
|
SmallVector<int64_t> expandedIndicesShape(indicesInputRankShape);
|
||||||
|
for (int64_t i = 0; i < inputRank; i++) {
|
||||||
|
if (tileShape[i] == 1 && i != dim) {
|
||||||
|
tileShape[i] = inputShape[i];
|
||||||
|
expandedIndicesShape[i] = inputShape[i];
|
||||||
|
} else {
|
||||||
|
tileShape[i] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tileType =
|
||||||
|
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
|
||||||
|
rewriter.getIntegerType(32));
|
||||||
|
|
||||||
|
auto expandedIndices = rewriter.create<tosa::TileOp>(
|
||||||
|
op->getLoc(), tileType, reshapedIndices.getResult(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(tileShape));
|
||||||
|
|
||||||
|
// convert torch style index and dim into tf style indices
|
||||||
|
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
|
||||||
|
auto indicesTf = tosa::convertTorchIndexToTfIndices(
|
||||||
|
rewriter, op, input, expandedIndices.getResult(), dim);
|
||||||
|
if (!indicesTf)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Convert TorchIndex To TfIndices failed");
|
||||||
|
|
||||||
|
// do the tf gathernd algorithm with tf style indices as input.
|
||||||
|
auto result =
|
||||||
|
tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value());
|
||||||
|
|
||||||
|
if (!result) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {result.value()});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||||
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
||||||
|
@ -5200,7 +5329,7 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename AtenOpT>
|
template <typename AtenOpT>
|
||||||
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
@ -5216,18 +5345,48 @@ public:
|
||||||
op, "Only Tensor types with static shapes are currently supported");
|
op, "Only Tensor types with static shapes are currently supported");
|
||||||
|
|
||||||
Type outElemTy = outType.getElementType();
|
Type outElemTy = outType.getElementType();
|
||||||
if (!outElemTy.isIntOrFloat()) {
|
if (!outElemTy.isIntOrFloat())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only floating-point or integer datatype legalization supported");
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
}
|
|
||||||
Value constOp;
|
Value fillValueTargetTensor;
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
|
||||||
|
// Reshape value tensor to have same rank and shape as input
|
||||||
|
auto inputRank =
|
||||||
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
|
|
||||||
|
auto fillValue = adaptor.getValue();
|
||||||
|
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
|
||||||
|
if (!fillValueType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
|
||||||
|
auto fillValueElemTy = fillValueType.getElementType();
|
||||||
|
|
||||||
|
SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
|
||||||
|
|
||||||
|
auto fillValueMatchedInputRankType = RankedTensorType::get(
|
||||||
|
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
|
||||||
|
fillValueElemTy);
|
||||||
|
|
||||||
|
auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), fillValueMatchedInputRankType, fillValue,
|
||||||
|
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
|
||||||
|
|
||||||
|
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
|
||||||
|
fillValueElemTy),
|
||||||
|
fillValueMatchedInputRankTensor.getResult(),
|
||||||
|
makeShapeTorchCompatible(outType.getShape()));
|
||||||
|
} else {
|
||||||
if (failed(torchScalarToTosaTensor(
|
if (failed(torchScalarToTosaTensor(
|
||||||
rewriter, op, op.getValue(), constOp, outElemTy,
|
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
|
||||||
makeShapeTorchCompatible(outType.getShape()))))
|
makeShapeTorchCompatible(outType.getShape()))))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Supplied value must be a Scalar constant");
|
op, "Fill value must be a scalar constant");
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
|
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
|
||||||
|
fillValueTargetTensor);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -5647,8 +5806,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Template to create support tril mask tensor for aten.tril
|
// Template to create supporting tril mask tensor for aten.tril
|
||||||
// legalization
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||||
|
@ -5671,28 +5829,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||||
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to get tril mask tensor based on input type
|
|
||||||
// for aten.tril legalization
|
|
||||||
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
|
|
||||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
|
||||||
int64_t diagonal, Type type) {
|
|
||||||
return TypeSwitch<Type, Value>(type)
|
|
||||||
.Case<mlir::FloatType>([&](auto) {
|
|
||||||
return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
|
|
||||||
})
|
|
||||||
.Case<mlir::IntegerType>([&](auto intType) {
|
|
||||||
switch (intType.getWidth()) {
|
|
||||||
case 1:
|
|
||||||
return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
|
|
||||||
case 32:
|
|
||||||
return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
|
|
||||||
case 64:
|
|
||||||
return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
|
|
||||||
}
|
|
||||||
llvm_unreachable("Invalid integer width");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Legalization for aten.tril
|
// Legalization for aten.tril
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||||
|
@ -5740,14 +5876,31 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
|
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
|
||||||
|
|
||||||
// Define shape for mask tensor based on rank
|
// Define shape for mask tensor based on rank
|
||||||
SmallVector<int64_t> constShape;
|
SmallVector<int64_t> maskShape;
|
||||||
for (auto i = 0; i < selfRank - 2; i++)
|
for (auto i = 0; i < selfRank - 2; i++)
|
||||||
constShape.push_back(1);
|
maskShape.push_back(1);
|
||||||
constShape.push_back(h);
|
maskShape.push_back(h);
|
||||||
constShape.push_back(w);
|
maskShape.push_back(w);
|
||||||
|
|
||||||
Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
|
Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
|
||||||
resultType.getElementType());
|
.Case<mlir::FloatType>([&](auto) {
|
||||||
|
return createTrilMask<float>(rewriter, op, maskShape,
|
||||||
|
h, w, diagonal);
|
||||||
|
})
|
||||||
|
.Case<mlir::IntegerType>([&](auto intType) {
|
||||||
|
switch (intType.getWidth()) {
|
||||||
|
case 1:
|
||||||
|
return createTrilMask<bool>(rewriter, op, maskShape,
|
||||||
|
h, w, diagonal);
|
||||||
|
case 32:
|
||||||
|
return createTrilMask<int32_t>(
|
||||||
|
rewriter, op, maskShape, h, w, diagonal);
|
||||||
|
case 64:
|
||||||
|
return createTrilMask<int64_t>(
|
||||||
|
rewriter, op, maskShape, h, w, diagonal);
|
||||||
|
}
|
||||||
|
llvm_unreachable("Invalid integer width");
|
||||||
|
});
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
|
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
|
||||||
/*shift=*/0);
|
/*shift=*/0);
|
||||||
|
@ -5755,6 +5908,311 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.flip
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||||
|
AtenFlipOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor types are currently supported");
|
||||||
|
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only constant dims are currently supported");
|
||||||
|
|
||||||
|
auto selfRank = selfTy.getRank();
|
||||||
|
|
||||||
|
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||||
|
Value result = self;
|
||||||
|
|
||||||
|
for (auto &dim : dims) {
|
||||||
|
dim = toPositiveDim(dim, selfRank);
|
||||||
|
if (!isValidDim(dim, selfRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||||
|
|
||||||
|
result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
|
||||||
|
static_cast<int32_t>(dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.round:
|
||||||
|
// Rounds elements of input to the nearest integer.
|
||||||
|
// Implements "round half to even" to break ties when a number is equidistant
|
||||||
|
// from two integers.
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
|
||||||
|
AtenRoundOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// To round to the nearest integer, we will consider the fractional part of
|
||||||
|
// the input element (= input element - integer part of element). If the
|
||||||
|
// fractional part is smaller than 0.5, round the number down. If the
|
||||||
|
// fractional part is 0.5, apply "round half to even" rule. If the fractional
|
||||||
|
// part is greater than 0.5, round up.
|
||||||
|
//
|
||||||
|
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
|
||||||
|
// res = floor(input)
|
||||||
|
// else:
|
||||||
|
// res = ceil(input)
|
||||||
|
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfTy = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types supported");
|
||||||
|
|
||||||
|
auto resultTy =
|
||||||
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
|
auto boolTy =
|
||||||
|
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
|
||||||
|
|
||||||
|
auto resultElemTy = resultTy.getElementType();
|
||||||
|
|
||||||
|
auto oneHalf =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
|
||||||
|
|
||||||
|
auto two =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
|
||||||
|
|
||||||
|
auto floorInput =
|
||||||
|
rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
|
||||||
|
|
||||||
|
// input - floor(input)
|
||||||
|
auto fractionalPart = rewriter.create<tosa::SubOp>(
|
||||||
|
op->getLoc(), resultTy, self, floorInput.getResult());
|
||||||
|
|
||||||
|
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
|
||||||
|
|
||||||
|
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
|
||||||
|
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
|
||||||
|
|
||||||
|
auto floorDivResult = rewriter.create<tosa::FloorOp>(
|
||||||
|
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
|
||||||
|
|
||||||
|
// (floor(input) // 2) * 2
|
||||||
|
auto evenComparison = rewriter.create<tosa::MulOp>(
|
||||||
|
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
|
||||||
|
|
||||||
|
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
|
||||||
|
auto floorInputEven = rewriter.create<tosa::EqualOp>(
|
||||||
|
op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
|
||||||
|
|
||||||
|
auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
|
||||||
|
op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
|
||||||
|
|
||||||
|
auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
|
||||||
|
op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
|
||||||
|
|
||||||
|
// (frac == 0.5) && (floor(input) % 2 == 0)
|
||||||
|
auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
|
||||||
|
op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
|
||||||
|
floorInputEven.getResult());
|
||||||
|
|
||||||
|
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
|
||||||
|
auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
|
||||||
|
op->getLoc(), boolTy, fracLtOneHalf.getResult(),
|
||||||
|
fracEqualOneHalfCond.getResult());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
|
||||||
|
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
|
||||||
|
ceilInput.getResult());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template to create supporting diagonal mask tensor for aten.diagonal
|
||||||
|
template <typename T>
|
||||||
|
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
|
||||||
|
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||||
|
int64_t offset) {
|
||||||
|
SmallVector<T> vec;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < h; i++) {
|
||||||
|
for (int64_t j = 0; j < w; j++) {
|
||||||
|
// Positive offset value moves above the main diagonal, while negative
|
||||||
|
// diagonal value moves below the main diagonal.
|
||||||
|
if (i + offset == j) {
|
||||||
|
vec.push_back(static_cast<T>(1));
|
||||||
|
} else {
|
||||||
|
vec.push_back(static_cast<T>(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.diagonal
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
||||||
|
AtenDiagonalOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a ranked tensor type
|
||||||
|
auto selfType = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor types are supported");
|
||||||
|
|
||||||
|
// Rank below 2 not accepted
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
if (selfRank <= 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Rank 0 and 1 are not accepted as they cause underflow");
|
||||||
|
|
||||||
|
if (!selfType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Currently only static shapes are supported");
|
||||||
|
|
||||||
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
|
if (!resultType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
|
||||||
|
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
auto resultElemTy = resultType.getElementType();
|
||||||
|
|
||||||
|
int64_t offset, dim1, dim2;
|
||||||
|
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
|
||||||
|
offset = 0;
|
||||||
|
|
||||||
|
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
|
||||||
|
dim1 = 0;
|
||||||
|
} else {
|
||||||
|
dim1 = toPositiveDim(dim1, selfRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
|
||||||
|
dim2 = 1;
|
||||||
|
} else {
|
||||||
|
dim2 = toPositiveDim(dim2, selfRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
|
||||||
|
int64_t h = selfShape[dim1];
|
||||||
|
int64_t w = selfShape[dim2];
|
||||||
|
|
||||||
|
// Overflowing offset not supported
|
||||||
|
if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Offset greater or equal than shape not supported");
|
||||||
|
|
||||||
|
int64_t targetDim1 = selfRank - 2;
|
||||||
|
int64_t targetDim2 = selfRank - 1;
|
||||||
|
|
||||||
|
Value selfTransposed = self;
|
||||||
|
SmallVector<int64_t> transposedInputShape = selfShape;
|
||||||
|
RankedTensorType transposedInputType = selfType;
|
||||||
|
|
||||||
|
// If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
|
||||||
|
// so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
|
||||||
|
// we can consistently create the diagonal mask tensor.
|
||||||
|
if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
|
||||||
|
SmallVector<int32_t> transposedDims;
|
||||||
|
transposedInputShape.clear();
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < selfRank; ++i) {
|
||||||
|
if (i == dim1 || i == dim2)
|
||||||
|
continue;
|
||||||
|
transposedDims.push_back(i);
|
||||||
|
}
|
||||||
|
transposedDims.push_back(dim1);
|
||||||
|
transposedDims.push_back(dim2);
|
||||||
|
|
||||||
|
auto transposedDimsConst = tosa::getConstTensor<int32_t>(
|
||||||
|
rewriter, op,
|
||||||
|
/*vec=*/transposedDims,
|
||||||
|
/*shape=*/{static_cast<int32_t>(selfRank)});
|
||||||
|
|
||||||
|
for (auto &dim : transposedDims)
|
||||||
|
transposedInputShape.push_back(selfShape[dim]);
|
||||||
|
|
||||||
|
transposedInputType = RankedTensorType::get(
|
||||||
|
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
|
||||||
|
|
||||||
|
selfTransposed = rewriter.create<tosa::TransposeOp>(
|
||||||
|
op->getLoc(), transposedInputType, self, transposedDimsConst.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define shape for mask tensor based on rank
|
||||||
|
SmallVector<int64_t> maskShape;
|
||||||
|
for (auto i = 0; i < selfRank - 2; i++)
|
||||||
|
maskShape.push_back(1);
|
||||||
|
maskShape.push_back(h);
|
||||||
|
maskShape.push_back(w);
|
||||||
|
|
||||||
|
Value diagonalMask =
|
||||||
|
TypeSwitch<Type, Value>(resultElemTy)
|
||||||
|
.Case<mlir::FloatType>([&](auto) {
|
||||||
|
return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
|
||||||
|
offset);
|
||||||
|
})
|
||||||
|
.Case<mlir::IntegerType>([&](auto intType) {
|
||||||
|
switch (intType.getWidth()) {
|
||||||
|
case 1:
|
||||||
|
return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
|
||||||
|
offset);
|
||||||
|
case 32:
|
||||||
|
return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
|
||||||
|
offset);
|
||||||
|
case 64:
|
||||||
|
return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
|
||||||
|
offset);
|
||||||
|
}
|
||||||
|
llvm_unreachable("Invalid integer width");
|
||||||
|
});
|
||||||
|
|
||||||
|
Value diagonalTensor = rewriter.create<tosa::MulOp>(
|
||||||
|
op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
|
||||||
|
/*shift=*/0);
|
||||||
|
|
||||||
|
auto resultShape = makeShapeTorchCompatible(resultType.getShape());
|
||||||
|
auto targetReduceDim = resultShape[resultType.getRank() - 1];
|
||||||
|
|
||||||
|
// If transposedInputShape[targetDim1] (or h) is greater than the innermost
|
||||||
|
// dim of the result, we won't get the correct shape when we reduce sum along
|
||||||
|
// the innermost dim to get the result. Therefore, we have to slice the
|
||||||
|
// transposed tensor so that transposedInputShape[targetDim1] ==
|
||||||
|
// targetReduceDim.
|
||||||
|
if (h > targetReduceDim) {
|
||||||
|
transposedInputShape[targetDim1] = targetReduceDim;
|
||||||
|
transposedInputType = RankedTensorType::get(
|
||||||
|
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
|
||||||
|
SmallVector<int64_t> startSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> sizeSlice =
|
||||||
|
llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
|
||||||
|
if (offset < 0)
|
||||||
|
startSlice[targetDim1] = std::abs(offset);
|
||||||
|
diagonalTensor = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), transposedInputType, diagonalTensor,
|
||||||
|
rewriter.getDenseI64ArrayAttr(startSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(sizeSlice));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply Reduce Sum to get the result
|
||||||
|
auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
|
||||||
|
auto reduceDimAttr =
|
||||||
|
DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
|
||||||
|
auto result =
|
||||||
|
mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
|
||||||
|
reduceDimAttr, /*keep_dims=*/false);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result.value());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -5986,11 +6444,13 @@ public:
|
||||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||||
|
|
||||||
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
|
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
|
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||||
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
|
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||||
#undef INSERT_FILL_SCALAR_PATTERN
|
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||||
|
INSERT_FILL_PATTERN(AtenFillTensorOp);
|
||||||
|
#undef INSERT_FILL_PATTERN
|
||||||
|
|
||||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
@ -6060,6 +6520,10 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -23,6 +23,15 @@ namespace tosa {
|
||||||
|
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
// This function is a helper for `convertTorchIndexToTfIndices`.
|
||||||
|
//
|
||||||
|
// We convert PyTorch index to TensorFlow-style indices so that we can use
|
||||||
|
// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather
|
||||||
|
// and Scatter operators to TOSA using TensorFlow-style indices.
|
||||||
|
// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow
|
||||||
|
// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want
|
||||||
|
// to gather/scatter elements, while in TensorFlow, the indices point directly
|
||||||
|
// to positions that you want to gather/scatter elements.
|
||||||
std::optional<Value>
|
std::optional<Value>
|
||||||
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||||
SmallVector<int64_t> indicesOneDimShape, int32_t dim,
|
SmallVector<int64_t> indicesOneDimShape, int32_t dim,
|
||||||
|
@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||||
unsigned indexRank = indexShape.size();
|
unsigned indexRank = indexShape.size();
|
||||||
SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
|
SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
|
||||||
SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
|
SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
|
||||||
int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid)
|
|
||||||
|
|
||||||
// Create torch.meshgrid inputs
|
// Create torch.meshgrid inputs
|
||||||
// Example: indexShape=[1,4,2]
|
// Example: indexShape=[1,4,2]
|
||||||
// dim0: indicesMetaElement = torch.arange(0, 1) = [0]
|
// dim0: indicesMetaElement = torch.arange(0, 1) = [0]
|
||||||
// dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
|
// dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
|
||||||
// dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
|
// dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
|
||||||
for (int i = 0; i < indexShape[dim]; i++) {
|
for (int i = 0; i < indexShape[dim]; i++)
|
||||||
indicesMetaElement.push_back(i);
|
indicesMetaElement.push_back(i);
|
||||||
}
|
|
||||||
|
|
||||||
// Compute total number of meta element repeat times:
|
int preDimMetaElementRepeatTimes = 1;
|
||||||
// = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim
|
int postDimMetaElementRepeatTimes = 1;
|
||||||
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
|
|
||||||
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
|
|
||||||
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
|
|
||||||
for (int i = 0; i < static_cast<int>(indexRank); i++) {
|
|
||||||
if (i == dim) {
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
indicesMetaElementRepeatTimes *= indexShape[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dim != static_cast<int>(indexShape.size()) - 1) {
|
// Compute total number of times meta element range should repeat
|
||||||
// Create one dim indices for index except for last dim
|
// = product(indexShape[0:dim])
|
||||||
// Create indices raw vector.
|
// dim0: preDimMetaElementRepeatTimes = 1
|
||||||
// torch.stack(torch.meshgrid)
|
// dim1: preDimMetaElementRepeatTimes = 1
|
||||||
// dim0: indicesVec = [0 0 0 0 0 0 0 0]
|
// dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4
|
||||||
// dim0: indicesVec = [0 0 1 1 2 2 3 3]
|
for (int i = 0; i < dim; i++)
|
||||||
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
preDimMetaElementRepeatTimes *= indexShape[i];
|
||||||
elementId++) {
|
|
||||||
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
|
// Compute total number of times meta element repeat
|
||||||
indicesVec.push_back(indicesMetaElement[elementId]);
|
// = product(indexShape[dim+1:indexRank])
|
||||||
}
|
// dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8
|
||||||
}
|
// dim1: postDimMetaElementRepeatTimes = 2
|
||||||
} else { // Create the one dim indices for last dim of index
|
// dim2: postDimMetaElementRepeatTimes = 1
|
||||||
// Create indices raw vector
|
for (int i = dim + 1; i < static_cast<int>(indexRank); i++)
|
||||||
// dim2: indicesVec= [0 1 0 1 0 1 0 1]
|
postDimMetaElementRepeatTimes *= indexShape[i];
|
||||||
// Caution: indicesVec != [0 0 0 0 1 1 1 1]
|
|
||||||
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
|
// Example using dim1:
|
||||||
|
// preDimMetaElementRepeatTimes = 1
|
||||||
|
// postDimMetaElementRepeatTimes = 2
|
||||||
|
// Using postDimMetaElementRepeatTimes, we get the meta element range:
|
||||||
|
// [0 0 1 1 2 2 3 3]
|
||||||
|
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
|
||||||
|
// [0 0 1 1 2 2 3 3]
|
||||||
|
//
|
||||||
|
// Let's use a clearer example:
|
||||||
|
// indexShape = [3, 4, 2]
|
||||||
|
// Target dim = 1
|
||||||
|
// => preDimMetaElementRepeatTimes = 3
|
||||||
|
// postDimMetaElementRepeatTimes = 2
|
||||||
|
// Using postDimMetaElementRepeatTimes, we get the meta element range:
|
||||||
|
// [0 0 1 1 2 2]
|
||||||
|
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
|
||||||
|
// [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2]
|
||||||
|
for (int i = 0; i < preDimMetaElementRepeatTimes; i++) {
|
||||||
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
||||||
elementId++) {
|
elementId++) {
|
||||||
|
for (int j = 0; j < postDimMetaElementRepeatTimes; j++) {
|
||||||
indicesVec.push_back(indicesMetaElement[elementId]);
|
indicesVec.push_back(indicesMetaElement[elementId]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -132,12 +132,28 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||||
Type elemTy) {
|
Type elemTy) {
|
||||||
Value initTensor =
|
Value initTensor =
|
||||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||||
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
|
|
||||||
Value c0 =
|
Type fillValElemTy = elemTy;
|
||||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
|
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
|
||||||
|
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||||
|
|
||||||
|
Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
|
||||||
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||||
|
Type elemTy) {
|
||||||
|
Value initTensor =
|
||||||
|
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||||
|
|
||||||
|
Type fillValElemTy = elemTy;
|
||||||
|
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
|
||||||
|
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||||
|
|
||||||
|
Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
|
||||||
|
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
|
Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
|
||||||
assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
|
assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
|
||||||
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
|
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
|
||||||
|
|
|
@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult BindSymbolicShapeOp::verify() {
|
LogicalResult BindSymbolicShapeOp::verify() {
|
||||||
if (getShapeSymbols().empty())
|
if (getShapeSymbols().size() !=
|
||||||
return emitOpError() << "requires non-empty shapeSymbols";
|
getShapeExpressions().getValue().getNumSymbols())
|
||||||
|
return emitOpError()
|
||||||
|
<< "requires equal number of shape symbol args and symbol args to "
|
||||||
|
"the attached affine map, since they are 1:1 mapped";
|
||||||
|
|
||||||
for (auto symbol : getShapeSymbols()) {
|
for (auto symbol : getShapeSymbols()) {
|
||||||
Operation *definingOp = symbol.getDefiningOp();
|
Operation *definingOp = symbol.getDefiningOp();
|
||||||
|
|
|
@ -9200,6 +9200,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" return %arg0 : !torch.list<int>\n"
|
" return %arg0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
|
" return %arg0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" return %arg0 : !torch.list<int>\n"
|
" return %arg0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
@ -10352,6 +10355,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||||
|
" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||||
|
@ -11895,6 +11910,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %1 : !torch.int\n"
|
" return %1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||||
|
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %2 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %int4 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %2#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If.yield %4 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
|
@ -14663,6 +14697,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %4 : !torch.int\n"
|
" return %4 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
@ -15601,6 +15639,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %2 : !torch.int\n"
|
" return %2 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||||
|
" %str = torch.constant.str \"size must be less than or equal to {}\"\n"
|
||||||
|
" %false = torch.constant.bool false\n"
|
||||||
|
" %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str_1 = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %str_2 = torch.constant.str \"dimension out of range of {}\"\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||||
|
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %3 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
|
||||||
|
" %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n"
|
||||||
|
" torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %4 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %5 : !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||||
|
" %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %15 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
|
||||||
|
" %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %false : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If %6 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
|
||||||
|
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
|
||||||
|
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %8 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n"
|
||||||
|
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
|
||||||
|
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" %13 = torch.aten._set_item.t %12, %4, %11 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||||
|
" %14 = torch.aten.append.t %12, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %12 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
"}\n"
|
"}\n"
|
||||||
"";
|
"";
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
|
@ -7298,6 +7298,85 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
|
||||||
|
// op.
|
||||||
|
class DecomposeAtenAdaptiveMaxPool1dOp
|
||||||
|
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
|
||||||
|
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
MLIRContext *context = op.getContext();
|
||||||
|
|
||||||
|
Value input = op.getSelf();
|
||||||
|
std::optional<unsigned> maybeRank = getTensorRank(input);
|
||||||
|
if (!maybeRank) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
||||||
|
}
|
||||||
|
unsigned rank = *maybeRank;
|
||||||
|
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(rank - 1));
|
||||||
|
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
|
||||||
|
|
||||||
|
Value outputShape = op.getOutputSize();
|
||||||
|
SmallVector<Value> outputShapeSizesTorchInt;
|
||||||
|
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
||||||
|
Value outputSize = outputShapeSizesTorchInt[0];
|
||||||
|
|
||||||
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||||
|
|
||||||
|
int64_t outputSizeInt;
|
||||||
|
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "the output size of adaptive_max_pool1d must be a constant int");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 1> kernelSize;
|
||||||
|
if (outputSizeInt == 1) {
|
||||||
|
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
|
||||||
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||||
|
kernelSize.push_back(
|
||||||
|
inputShape[rank - 1] == kUnknownSize
|
||||||
|
? inputSize
|
||||||
|
: rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
|
||||||
|
} else {
|
||||||
|
if (!isAssumingStrictSymbolicShapes(rewriter)) {
|
||||||
|
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
|
||||||
|
rewriter.create<RuntimeAssertOp>(
|
||||||
|
loc, cond,
|
||||||
|
"unimplemented: only support cases where input and output size are "
|
||||||
|
"equal for non-unit output size");
|
||||||
|
}
|
||||||
|
kernelSize.push_back(constantOne);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
||||||
|
Value strideList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||||
|
ValueRange{constantOne});
|
||||||
|
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||||
|
ValueRange{constantZero});
|
||||||
|
Value dialationList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||||
|
ValueRange{constantOne});
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
|
||||||
|
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
|
||||||
|
paddingSizeList, dialationList,
|
||||||
|
/*ceil_mode=*/constantFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
|
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
|
||||||
|
|
||||||
|
@ -8720,6 +8799,77 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
|
||||||
|
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
|
||||||
|
using OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto self = op.getSelf();
|
||||||
|
auto target = op.getTarget();
|
||||||
|
auto posWeight = op.getPosWeight();
|
||||||
|
auto weight = op.getWeight();
|
||||||
|
auto reduction = op.getReduction();
|
||||||
|
|
||||||
|
Value loss;
|
||||||
|
auto one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
auto _one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||||
|
|
||||||
|
auto _target =
|
||||||
|
rewriter.create<AtenMulScalarOp>(loc, target.getType(), target, _one);
|
||||||
|
auto _target_1 = rewriter.create<AtenAddScalarOp>(loc, _target.getType(),
|
||||||
|
_target, one, one);
|
||||||
|
Value mm =
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, self.getType(), _target_1, self);
|
||||||
|
Value logSigm =
|
||||||
|
rewriter.create<AtenLogSigmoidOp>(loc, self.getType(), self);
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(posWeight.getType())) {
|
||||||
|
auto logWeight = rewriter.create<AtenAddScalarOp>(
|
||||||
|
loc, posWeight.getType(),
|
||||||
|
rewriter.create<AtenSubScalarOp>(loc, posWeight.getType(), posWeight,
|
||||||
|
one, one),
|
||||||
|
one, one);
|
||||||
|
loss = rewriter.create<AtenSubTensorOp>(
|
||||||
|
loc, mm.getType(), mm,
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, logWeight.getType(), logWeight,
|
||||||
|
logSigm),
|
||||||
|
one);
|
||||||
|
} else {
|
||||||
|
loss =
|
||||||
|
rewriter.create<AtenSubTensorOp>(loc, mm.getType(), mm, logSigm, one);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||||
|
loss =
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, loss.getType(), loss, weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply loss reduction.
|
||||||
|
int64_t reductionInt;
|
||||||
|
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "no reduction type is appointed!");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value res;
|
||||||
|
if (reductionInt == 1) {
|
||||||
|
res = rewriter.create<AtenMeanOp>(loc, op.getType(), loss, none);
|
||||||
|
} else if (reductionInt == 2) {
|
||||||
|
res = rewriter.create<AtenSumOp>(loc, op.getType(), loss, none);
|
||||||
|
} else {
|
||||||
|
res = loss;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, res);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
||||||
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
||||||
|
@ -9801,6 +9951,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
||||||
|
@ -9856,6 +10007,8 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
|
||||||
|
patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
|
||||||
|
|
|
@ -530,11 +530,139 @@ public:
|
||||||
none, none, none, none);
|
none, none, none, none);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
|
||||||
|
if (squeezeOp && resultTy.getSizes().size() == 1) {
|
||||||
|
rewriter.replaceOp(op, squeezeOp.getSelf());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This is a specific pattern for converting views like [?,...,?,lastDim] ->
|
||||||
|
// [?,...,?,factor0,factor1] to unflatten, and views like
|
||||||
|
// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
|
||||||
|
// possible to infer that all but last shared dim match
|
||||||
|
// TODO: move this to an actual canonicalizer for view after deleting the
|
||||||
|
// conflicting decompositions for flatten/unflatten -> view.
|
||||||
|
class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenViewOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenViewOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Value> viewSizes;
|
||||||
|
if (failed(getListOperands(op.getSize(), viewSizes)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "view size must be from a list construct");
|
||||||
|
auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
|
||||||
|
if (!selfTy || !selfTy.hasSizes())
|
||||||
|
return rewriter.notifyMatchFailure(op, "missing input type or sizes");
|
||||||
|
auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
|
||||||
|
if (!resultTy || !resultTy.hasSizes() ||
|
||||||
|
resultTy.getSizes().size() != viewSizes.size())
|
||||||
|
return rewriter.notifyMatchFailure(op, "missing result type or sizes");
|
||||||
|
int64_t inRank = selfTy.getSizes().size();
|
||||||
|
int64_t outRank = resultTy.getSizes().size();
|
||||||
|
|
||||||
|
SmallVector<int64_t> sizes(selfTy.getSizes());
|
||||||
|
int64_t endMatchingDim = -1;
|
||||||
|
// input sizes vs. provided view sizes comparison loop
|
||||||
|
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
|
||||||
|
int64_t providedSize;
|
||||||
|
bool providedStatic =
|
||||||
|
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
|
||||||
|
// if sizes[i] is static, it must match a constant in viewSizes[i]
|
||||||
|
if (sizes[i] != Torch::kUnknownSize) {
|
||||||
|
if (!providedStatic)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unsupported: found static input dim, but unable to match "
|
||||||
|
"provided view size on a constant. See position : " +
|
||||||
|
std::to_string(i));
|
||||||
|
if (providedSize != sizes[i]) {
|
||||||
|
endMatchingDim = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// the remaining assumes sizes[i] is dynamic
|
||||||
|
// if provided dim is static, we can't verify it is a flatten/unflatten
|
||||||
|
// unless -1
|
||||||
|
if (i == outRank - 1 && providedStatic && providedSize == -1) {
|
||||||
|
endMatchingDim = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (providedStatic)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unexpected static view dim corresponding to dynamic input dim "
|
||||||
|
"at position : " +
|
||||||
|
std::to_string(i));
|
||||||
|
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
|
||||||
|
// if we don't have a size int op on self, fail
|
||||||
|
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected dynamic view dim to come from a corresponding "
|
||||||
|
"size.int op. See position : " +
|
||||||
|
std::to_string(i));
|
||||||
|
int64_t dim;
|
||||||
|
// if the dim of the size int op doesn't match, fail
|
||||||
|
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
|
||||||
|
dim != i)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"size int op dim cannot be matched to current dim at position : " +
|
||||||
|
std::to_string(i));
|
||||||
|
// passing the previous checks means viewSizes[i] = aten.size.int(self,
|
||||||
|
// i), so continue
|
||||||
|
}
|
||||||
|
// if all dims match and the ranks are equal, fold
|
||||||
|
if (endMatchingDim == -1 && inRank == outRank) {
|
||||||
|
rewriter.replaceOp(op, op.getSelf());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (endMatchingDim > -1 && inRank > outRank) {
|
||||||
|
// only support flattening last dim
|
||||||
|
if (endMatchingDim != outRank - 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: output has more than back dim mismatching");
|
||||||
|
// flatten
|
||||||
|
Value start =
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
|
||||||
|
Value end =
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
|
||||||
|
op, resultTy, op.getSelf(), start, end);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (endMatchingDim > -1 && inRank < outRank) {
|
||||||
|
// only support unflattening last dim
|
||||||
|
if (endMatchingDim != inRank - 1)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: input has more than back dim mismatching");
|
||||||
|
// unflatten
|
||||||
|
Value dim =
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
|
||||||
|
Value primList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
op.getLoc(), op.getSize().getType(),
|
||||||
|
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
|
||||||
|
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
|
||||||
|
op, resultTy, op.getSelf(), dim, primList);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
// examples that might reach this:
|
||||||
|
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
|
||||||
|
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
|
||||||
|
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
|
||||||
|
", inRank=" + std::to_string(inRank) +
|
||||||
|
", outRank=" + std::to_string(outRank));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
|
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
|
||||||
public:
|
public:
|
||||||
|
@ -561,12 +689,18 @@ public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
patterns
|
patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
|
||||||
.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
|
|
||||||
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
||||||
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
|
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
|
||||||
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
|
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
|
||||||
FoldAtenWhereSelf, RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
|
||||||
|
RemoveUnusedPattern<Torch::AtenIntBoolOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenEqIntOp>,
|
||||||
|
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenFullOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||||
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
||||||
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
||||||
RemoveUnusedPattern<Torch::ConstantBoolOp>,
|
RemoveUnusedPattern<Torch::ConstantBoolOp>,
|
||||||
|
|
|
@ -90,7 +90,28 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
return torch_upstream::ScalarType::Float8_e5m2fnuz;
|
return torch_upstream::ScalarType::Float8_e5m2fnuz;
|
||||||
if (isa<Float8E4M3FNUZType>(type))
|
if (isa<Float8E4M3FNUZType>(type))
|
||||||
return torch_upstream::ScalarType::Float8_e4m3fnuz;
|
return torch_upstream::ScalarType::Float8_e4m3fnuz;
|
||||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
std::string errorMsg = "Unhandled type in getScalarTypeForType: ";
|
||||||
|
llvm::raw_string_ostream os(errorMsg);
|
||||||
|
type.print(os);
|
||||||
|
// os << "\nType ID: " << type.getTypeID();
|
||||||
|
os << "\nType properties:";
|
||||||
|
os << "\n Is integer: " << (type.isInteger() ? "yes" : "no");
|
||||||
|
os << "\n Is float: "
|
||||||
|
<< (type.isIntOrFloat() && !type.isInteger() ? "yes" : "no");
|
||||||
|
os << "\n Is index: " << (type.isIndex() ? "yes" : "no");
|
||||||
|
os << "\n Bit width: "
|
||||||
|
<< (type.isIntOrFloat() ? std::to_string(type.getIntOrFloatBitWidth())
|
||||||
|
: "N/A");
|
||||||
|
os << "\n Is signless: " << (type.isSignlessInteger() ? "yes" : "no");
|
||||||
|
os << "\n Is signed: " << (type.isSignedInteger() ? "yes" : "no");
|
||||||
|
// special error message for unsigned integer
|
||||||
|
if (type.isUnsignedInteger()) {
|
||||||
|
os << "\n Is unsigned: yes";
|
||||||
|
os << "\nUnsigned integer support is currently spotty. Please seeheck "
|
||||||
|
"https://github.com/llvm/torch-mlir/issues/3720 "
|
||||||
|
"for more details.";
|
||||||
|
}
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(errorMsg));
|
||||||
}
|
}
|
||||||
Type Torch::getTypeForTorchType(
|
Type Torch::getTypeForTorchType(
|
||||||
MLIRContext *context, Type type,
|
MLIRContext *context, Type type,
|
||||||
|
@ -257,7 +278,7 @@ bool Torch::isViewLikeOp(Operation *op) {
|
||||||
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
||||||
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
||||||
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
|
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
|
||||||
AtenPixelShuffleOp, AtenDiagonalOp>(op);
|
AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
||||||
|
|
|
@ -79,6 +79,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
#### General TorchDynamo/PyTorch errors
|
#### General TorchDynamo/PyTorch errors
|
||||||
# torch._dynamo.exc.Unsupported: Tensor.item
|
# torch._dynamo.exc.Unsupported: Tensor.item
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
|
"CumprodModule_basic",
|
||||||
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
||||||
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
||||||
# https://github.com/pytorch/pytorch/issues/89629
|
# https://github.com/pytorch/pytorch/issues/89629
|
||||||
|
@ -432,6 +433,7 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"ConvolutionBackwardModule2DStrided_basic",
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
"ConvolutionBackwardModule2D_basic",
|
"ConvolutionBackwardModule2D_basic",
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
|
"CumprodModule_basic",
|
||||||
"DeformConv2D_basic",
|
"DeformConv2D_basic",
|
||||||
"DivFloatModule_basic",
|
"DivFloatModule_basic",
|
||||||
"DivIntModule_basic",
|
"DivIntModule_basic",
|
||||||
|
@ -504,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"UpSampleNearest2dDynamicFactor_basic",
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
|
"ViewDtypeStaticModule_basic",
|
||||||
"WeightNormInterfaceModule_basic",
|
"WeightNormInterfaceModule_basic",
|
||||||
# Error: `aten.as_strided` op is not supported
|
# Error: `aten.as_strided` op is not supported
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
|
@ -588,6 +591,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"AdaptiveAvgPool3dDynamic_basic",
|
"AdaptiveAvgPool3dDynamic_basic",
|
||||||
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
||||||
"AdaptiveMaxPool1dDynamic_basic",
|
"AdaptiveMaxPool1dDynamic_basic",
|
||||||
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
"AdaptiveMaxPool1dStatic_basic",
|
"AdaptiveMaxPool1dStatic_basic",
|
||||||
"AdaptiveMaxPool2dDynamicNoBatch_basic",
|
"AdaptiveMaxPool2dDynamicNoBatch_basic",
|
||||||
"AdaptiveMaxPool2dDynamicWithIndices_basic",
|
"AdaptiveMaxPool2dDynamicWithIndices_basic",
|
||||||
|
@ -666,6 +670,10 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ConvolutionBackwardModule2DStrided_basic",
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
"ConvolutionBackwardModule2D_basic",
|
"ConvolutionBackwardModule2D_basic",
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
|
"CumprodModule_basic",
|
||||||
|
"CumprodInputDtypeInt32Module_basic",
|
||||||
|
"CumprodStaticModule_basic",
|
||||||
|
"CumprodStaticNegativeDimModule_basic",
|
||||||
"DeformConv2D_basic",
|
"DeformConv2D_basic",
|
||||||
"DeterminantBatchedModule_F32",
|
"DeterminantBatchedModule_F32",
|
||||||
"DeterminantDynamicModule_F32",
|
"DeterminantDynamicModule_F32",
|
||||||
|
@ -808,10 +816,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"RandnLikeDtypeModule_basic",
|
"RandnLikeDtypeModule_basic",
|
||||||
"RandnLikeModule_basic",
|
"RandnLikeModule_basic",
|
||||||
"RandnModule_basic",
|
"RandnModule_basic",
|
||||||
"ReduceAllDimBool_basic",
|
|
||||||
"ReduceAllDimEmpty_basic",
|
|
||||||
"ReduceAllDimFloat_basic",
|
|
||||||
"ReduceAllDimInt_basic",
|
|
||||||
"ReduceProdDimIntFloatModule_basic",
|
"ReduceProdDimIntFloatModule_basic",
|
||||||
"ReflectionPad1dModule2dInput_Right",
|
"ReflectionPad1dModule2dInput_Right",
|
||||||
"ReflectionPad1dModule2dInput_basic",
|
"ReflectionPad1dModule2dInput_basic",
|
||||||
|
@ -829,18 +833,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ReplicationPad2dModule_top0",
|
"ReplicationPad2dModule_top0",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
# need aten.all.dim lowering to stablehlo
|
|
||||||
"SafeSoftmaxModule_basic",
|
|
||||||
"SafeSoftmaxNonNoneDtypeModule_basic",
|
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
|
||||||
"ScaledDotProductAttentionMaskModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
|
||||||
"ScatterReduceFloatMaxModule",
|
"ScatterReduceFloatMaxModule",
|
||||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||||
"ScatterReduceFloatMeanModule",
|
"ScatterReduceFloatMeanModule",
|
||||||
|
@ -926,6 +919,11 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"GCDBatchedModule_I32",
|
"GCDBatchedModule_I32",
|
||||||
"GCDDynamicModule_I32",
|
"GCDDynamicModule_I32",
|
||||||
"GCDModule_I32",
|
"GCDModule_I32",
|
||||||
|
"Unfold_Module_basic",
|
||||||
|
"Unfold_Module_Rank_4",
|
||||||
|
"Unfold_Module_Rank_Zero_basic",
|
||||||
|
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||||
|
"Unfold_Module_Dynamic_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
||||||
|
@ -1059,6 +1057,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"ContainsIntList_False",
|
"ContainsIntList_False",
|
||||||
"ContainsIntList_True",
|
"ContainsIntList_True",
|
||||||
"ContiguousModule_basic",
|
"ContiguousModule_basic",
|
||||||
|
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
|
@ -1079,6 +1078,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"CumsumInputDtypeInt32Module_basic",
|
"CumsumInputDtypeInt32Module_basic",
|
||||||
"CumsumStaticModule_basic",
|
"CumsumStaticModule_basic",
|
||||||
"CumsumStaticNegativeDimModule_basic",
|
"CumsumStaticNegativeDimModule_basic",
|
||||||
|
"CumprodInputDtypeInt32Module_basic",
|
||||||
|
"CumprodStaticModule_basic",
|
||||||
|
"CumprodStaticNegativeDimModule_basic",
|
||||||
"DetachModule_basic",
|
"DetachModule_basic",
|
||||||
"DivFloatModule_basic",
|
"DivFloatModule_basic",
|
||||||
"DivIntModule_basic",
|
"DivIntModule_basic",
|
||||||
|
@ -1425,6 +1427,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"SliceSizeTwoStepModule_basic",
|
"SliceSizeTwoStepModule_basic",
|
||||||
"SliceStartEqEndModule_basic",
|
"SliceStartEqEndModule_basic",
|
||||||
"SliceStaticModule_basic",
|
"SliceStaticModule_basic",
|
||||||
|
"SliceStaticComplexInputModule_basic",
|
||||||
"SliceWholeTensorModule_basic",
|
"SliceWholeTensorModule_basic",
|
||||||
"SortIntListReverse_basic",
|
"SortIntListReverse_basic",
|
||||||
"SortIntList_basic",
|
"SortIntList_basic",
|
||||||
|
@ -1671,6 +1674,35 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"AtenRoundFloatHalfToEvenModule_basic",
|
||||||
|
"AtenRoundFloatModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||||
|
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||||
|
"Fill_TensorFloat64WithInt64Static_basic",
|
||||||
|
"FlipModuleStaticShape_basic",
|
||||||
|
"FlipModule_basic",
|
||||||
|
"FlipNegativeIndexModule_basic",
|
||||||
|
"Rot90BasicModule_basic",
|
||||||
|
"Rot90DynamicDimsModule_basic",
|
||||||
|
"Rot90MultipleRotationsModule_basic",
|
||||||
|
"Rot90NegativeEvenRotationsModule_basic",
|
||||||
|
"Rot90NegativeOddRotationsModule_basic",
|
||||||
|
"AtenLinalgCrossBroadcast_basic",
|
||||||
|
"AtenLinalgCrossCustomDim_basic",
|
||||||
|
"AtenLinalgCrossFloat_basic",
|
||||||
|
"AtenLinalgCrossInt_basic",
|
||||||
|
"AtenLinalgCrossNegativeDim_basic",
|
||||||
|
"BinaryCrossEntropyWithLogitsStaticModule_basic",
|
||||||
|
"IndexSelectNegativeDimModule_basic",
|
||||||
|
"IndexSelectSingleIdxModule_basic",
|
||||||
|
"IndexSelectTwoIdxModule_basic",
|
||||||
|
"IndexSelectWholeDimensionModule_basic",
|
||||||
|
"IndexSelectWholeTensorModule_basic",
|
||||||
|
"DiagonalWithStaticShapeModule_basic",
|
||||||
|
"EinsumStaticDiagonalDimensionModule_basic",
|
||||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||||
|
@ -1814,7 +1846,6 @@ TOSA_PASS_SET = {
|
||||||
"ArangeStartOutModule_basic",
|
"ArangeStartOutModule_basic",
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
"ArangeStartStepIntModule_basic",
|
"ArangeStartStepIntModule_basic",
|
||||||
"ArangeZeroElementOutputModule_basic",
|
|
||||||
"ArangeDtypeIntModule_basic",
|
"ArangeDtypeIntModule_basic",
|
||||||
"ArangeFalsePinMemoryModule_basic",
|
"ArangeFalsePinMemoryModule_basic",
|
||||||
"ArangeFloatModule_basic",
|
"ArangeFloatModule_basic",
|
||||||
|
@ -2115,7 +2146,6 @@ TOSA_PASS_SET = {
|
||||||
"NormScalarOptDimModule_basic",
|
"NormScalarOptDimModule_basic",
|
||||||
"NumToTensorFloatModule_basic",
|
"NumToTensorFloatModule_basic",
|
||||||
"NumToTensorIntModule_basic",
|
"NumToTensorIntModule_basic",
|
||||||
"NumpyTRank0Module_basic",
|
|
||||||
"NumpyTRank1Module_basic",
|
"NumpyTRank1Module_basic",
|
||||||
"NumpyTRank2Module_basic",
|
"NumpyTRank2Module_basic",
|
||||||
"NumpyTRankNDynamicModule_basic",
|
"NumpyTRankNDynamicModule_basic",
|
||||||
|
@ -2127,7 +2157,6 @@ TOSA_PASS_SET = {
|
||||||
"OnesModuleInt_basic",
|
"OnesModuleInt_basic",
|
||||||
"PadModule_basic",
|
"PadModule_basic",
|
||||||
"PadWithNoneValModule_basic",
|
"PadWithNoneValModule_basic",
|
||||||
"Permute0RankModule_basic",
|
|
||||||
"PermuteModule_basic",
|
"PermuteModule_basic",
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
"PrimListUnpackNumMismatchModule_basic",
|
"PrimListUnpackNumMismatchModule_basic",
|
||||||
|
@ -2166,7 +2195,6 @@ TOSA_PASS_SET = {
|
||||||
"ScalarTensorInt64Module_basic",
|
"ScalarTensorInt64Module_basic",
|
||||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||||
"SiluModule_basic",
|
"SiluModule_basic",
|
||||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
|
||||||
"SliceStaticModule_basic",
|
"SliceStaticModule_basic",
|
||||||
"SplitTensorGetItem_Module_basic",
|
"SplitTensorGetItem_Module_basic",
|
||||||
"SplitTensorLastSmallerModule_basic",
|
"SplitTensorLastSmallerModule_basic",
|
||||||
|
@ -2348,6 +2376,13 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
}
|
}
|
||||||
) - {
|
) - {
|
||||||
### Test failing in make_fx_tosa but not in tosa
|
### Test failing in make_fx_tosa but not in tosa
|
||||||
|
"ChunkListUnpackUneven_Module_basic",
|
||||||
|
"ChunkListUnpack_Module_basic",
|
||||||
|
"SplitTensorGetItem_Module_basic",
|
||||||
|
"SplitTensorLastSmallerModule_basic",
|
||||||
|
"SplitTensorListUnpackModule_basic",
|
||||||
|
"SplitTensorNegativeDimModule_basic",
|
||||||
|
"SplitWithSizesListUnpackModule_basic",
|
||||||
# Dynamic shape, has extra unsupported broadcast ops
|
# Dynamic shape, has extra unsupported broadcast ops
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
|
@ -2588,6 +2623,7 @@ ONNX_XFAIL_SET = {
|
||||||
"SliceCopyNegative_Module_basic",
|
"SliceCopyNegative_Module_basic",
|
||||||
"SliceCopyNonZeroDim_Module_basic",
|
"SliceCopyNonZeroDim_Module_basic",
|
||||||
"SliceCopy_Module_basic",
|
"SliceCopy_Module_basic",
|
||||||
|
"SliceStaticComplexInputModule_basic",
|
||||||
"StdCorrectionLargeInputModule_basic",
|
"StdCorrectionLargeInputModule_basic",
|
||||||
"TupleModule_basic",
|
"TupleModule_basic",
|
||||||
"VarCorrectionLargeInputModule_basic",
|
"VarCorrectionLargeInputModule_basic",
|
||||||
|
@ -2757,6 +2793,7 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseExpm1IntModule_basic",
|
"ElementwiseExpm1IntModule_basic",
|
||||||
"ElementwiseExpm1Module_basic",
|
"ElementwiseExpm1Module_basic",
|
||||||
"ElementwiseFmodTensor_Int_basic",
|
"ElementwiseFmodTensor_Int_basic",
|
||||||
|
"ElementwiseCreateComplexModule_basic",
|
||||||
"ElementwiseMulTensorComplexModule_basic",
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||||
"ElementwiseOrTensorModule_basic",
|
"ElementwiseOrTensorModule_basic",
|
||||||
|
@ -3071,7 +3108,6 @@ ONNX_XFAIL_SET = {
|
||||||
"ScatterReduceIntMaxModuleIncludeSelf",
|
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||||
"ScatterReduceIntMinModuleIncludeSelf",
|
"ScatterReduceIntMinModuleIncludeSelf",
|
||||||
"ScatterValueFloatModule_basic",
|
"ScatterValueFloatModule_basic",
|
||||||
"ScatterAddStaticModule_basic",
|
|
||||||
# Failure - onnx_lowering: onnx.ScatterND
|
# Failure - onnx_lowering: onnx.ScatterND
|
||||||
"IndexPut1DFloatAccumulateModule_basic",
|
"IndexPut1DFloatAccumulateModule_basic",
|
||||||
"IndexPut1DIntAccumulateModule_basic",
|
"IndexPut1DIntAccumulateModule_basic",
|
||||||
|
@ -3107,6 +3143,10 @@ ONNX_XFAIL_SET = {
|
||||||
"CopyWithDifferentDTypesModule_basic",
|
"CopyWithDifferentDTypesModule_basic",
|
||||||
"CosineSimilarityStaticBroadcastModule_basic",
|
"CosineSimilarityStaticBroadcastModule_basic",
|
||||||
"CumsumInputDtypeInt32Module_basic",
|
"CumsumInputDtypeInt32Module_basic",
|
||||||
|
"CumprodModule_basic",
|
||||||
|
"CumprodInputDtypeInt32Module_basic",
|
||||||
|
"CumprodStaticModule_basic",
|
||||||
|
"CumprodStaticNegativeDimModule_basic",
|
||||||
"ElementwiseAcosIntModule_basic",
|
"ElementwiseAcosIntModule_basic",
|
||||||
"ElementwiseAsinIntModule_basic",
|
"ElementwiseAsinIntModule_basic",
|
||||||
"ElementwiseAtanTensorIntModule_basic",
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
|
@ -3132,6 +3172,11 @@ ONNX_XFAIL_SET = {
|
||||||
"GCDBatchedModule_I32",
|
"GCDBatchedModule_I32",
|
||||||
"GCDDynamicModule_I32",
|
"GCDDynamicModule_I32",
|
||||||
"GCDModule_I32",
|
"GCDModule_I32",
|
||||||
|
"Unfold_Module_Rank_4",
|
||||||
|
"Unfold_Module_Rank_Zero_basic",
|
||||||
|
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||||
|
"Unfold_Module_Dynamic_basic",
|
||||||
|
"ViewDtypeStaticModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||||
|
@ -3170,6 +3215,18 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if torch_version_for_comparison() > version.parse("2.4.0.dev"):
|
||||||
|
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
|
||||||
|
"ElementwiseCreateComplexModule_basic",
|
||||||
|
"ElementwiseTanIntModule_basic",
|
||||||
|
"ElementwiseTanModule_basic",
|
||||||
|
}
|
||||||
|
FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
|
||||||
|
"ElementwiseCreateComplexModule_basic",
|
||||||
|
"ElementwiseTanIntModule_basic",
|
||||||
|
"ElementwiseTanModule_basic",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
"FakeQuantizePerTensorAffineModule_basic",
|
"FakeQuantizePerTensorAffineModule_basic",
|
||||||
|
@ -3197,17 +3254,30 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"ArangeZeroElementOutputModule_basic",
|
||||||
|
"NumpyTRank0Module_basic",
|
||||||
|
"Permute0RankModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||||
|
"SliceStartEqEndModule_basic",
|
||||||
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
|
"ChunkListUnpackUneven_Module_basic",
|
||||||
|
"ChunkListUnpack_Module_basic",
|
||||||
|
"SplitTensorGetItem_Module_basic",
|
||||||
|
"SplitTensorLastSmallerModule_basic",
|
||||||
|
"SplitTensorListUnpackModule_basic",
|
||||||
|
"SplitTensorNegativeDimModule_basic",
|
||||||
|
"SplitWithSizesListUnpackModule_basic",
|
||||||
|
"SplitWithSizes_Module_basic",
|
||||||
|
"ElementwiseCreateComplexModule_basic",
|
||||||
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
"AtenPolarDoubleModule_basic",
|
"AtenPolarDoubleModule_basic",
|
||||||
"AtenPolarFloatModule_basic",
|
"AtenPolarFloatModule_basic",
|
||||||
"HstackBasicComplexModule_basic",
|
"HstackBasicComplexModule_basic",
|
||||||
"HstackBasicFloatModule_basic",
|
"HstackBasicFloatModule_basic",
|
||||||
"HstackBasicIntFloatModule_basic",
|
"HstackBasicIntFloatModule_basic",
|
||||||
"HstackBasicIntModule_basic",
|
"HstackBasicIntModule_basic",
|
||||||
"Rot90BasicModule_basic",
|
|
||||||
"Rot90DynamicDimsModule_basic",
|
|
||||||
"Rot90MultipleRotationsModule_basic",
|
|
||||||
"Rot90NegativeEvenRotationsModule_basic",
|
|
||||||
"Rot90NegativeOddRotationsModule_basic",
|
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
"AtenKthvalueDynamicDimsModule_basic",
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
|
@ -3220,14 +3290,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"Conv_Transpose2dStaticModule_basic",
|
"Conv_Transpose2dStaticModule_basic",
|
||||||
"Conv_Transpose3dModule_basic",
|
"Conv_Transpose3dModule_basic",
|
||||||
"Conv_Transpose3dStaticModule_basic",
|
"Conv_Transpose3dStaticModule_basic",
|
||||||
"EinsumStaticDiagonalDimensionModule_basic",
|
|
||||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||||
"ElementwiseRreluEvalModule_basic",
|
"ElementwiseRreluEvalModule_basic",
|
||||||
"ElementwiseRreluEvalStaticModule_basic",
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
"ElementwiseRreluTrainModule_basic",
|
"ElementwiseRreluTrainModule_basic",
|
||||||
"ElementwiseRreluTrainStaticModule_basic",
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
|
||||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||||
"MaskedScatterStaticBasic_basic",
|
"MaskedScatterStaticBasic_basic",
|
||||||
"MaxUnpool3dModulePad0_basic",
|
"MaxUnpool3dModulePad0_basic",
|
||||||
|
@ -3294,12 +3362,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenIntTensorCharDtypeModule_basic",
|
"AtenIntTensorCharDtypeModule_basic",
|
||||||
"AtenItemFpOpModule_basic",
|
"AtenItemFpOpModule_basic",
|
||||||
"AtenItemIntOpModule_basic",
|
"AtenItemIntOpModule_basic",
|
||||||
"AtenLinalgCrossBroadcast_basic",
|
|
||||||
"AtenLinalgCrossCustomDim_basic",
|
|
||||||
"AtenLinalgCrossDynamic_basic",
|
|
||||||
"AtenLinalgCrossFloat_basic",
|
|
||||||
"AtenLinalgCrossInt_basic",
|
|
||||||
"AtenLinalgCrossNegativeDim_basic",
|
|
||||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||||
"AtenMatmulQMixedSigni8_basic",
|
"AtenMatmulQMixedSigni8_basic",
|
||||||
"AtenMatmulQint8MV_basic",
|
"AtenMatmulQint8MV_basic",
|
||||||
|
@ -3312,8 +3374,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenMmQuint8_basic",
|
"AtenMmQuint8_basic",
|
||||||
"AtenRealView128Module_basic",
|
"AtenRealView128Module_basic",
|
||||||
"AtenRealView64Module_basic",
|
"AtenRealView64Module_basic",
|
||||||
"AtenRoundFloatHalfToEvenModule_basic",
|
|
||||||
"AtenRoundFloatModule_basic",
|
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenTopKModule_basic",
|
"AtenTopKModule_basic",
|
||||||
"AtenTopKSmallestModule_basic",
|
"AtenTopKSmallestModule_basic",
|
||||||
|
@ -3355,6 +3415,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ContainsIntList_False",
|
"ContainsIntList_False",
|
||||||
"ContainsIntList_True",
|
"ContainsIntList_True",
|
||||||
"Conv1dModule_basic",
|
"Conv1dModule_basic",
|
||||||
|
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
"Conv2dQInt8Module_depthwise",
|
"Conv2dQInt8Module_depthwise",
|
||||||
"Conv2dQInt8Module_grouped",
|
"Conv2dQInt8Module_grouped",
|
||||||
|
@ -3383,18 +3444,14 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
"CumsumStaticModule_basic",
|
"CumsumStaticModule_basic",
|
||||||
"CumsumStaticNegativeDimModule_basic",
|
"CumsumStaticNegativeDimModule_basic",
|
||||||
|
"CumprodModule_basic",
|
||||||
|
"CumprodInputDtypeInt32Module_basic",
|
||||||
|
"CumprodStaticModule_basic",
|
||||||
|
"CumprodStaticNegativeDimModule_basic",
|
||||||
"DeformConv2D_basic",
|
"DeformConv2D_basic",
|
||||||
"DeterminantBatchedModule_F32",
|
"DeterminantBatchedModule_F32",
|
||||||
"DeterminantDynamicModule_F32",
|
"DeterminantDynamicModule_F32",
|
||||||
"DeterminantModule_F32",
|
"DeterminantModule_F32",
|
||||||
"DiagonalModule_basic",
|
|
||||||
"DiagonalModule_nonsquare",
|
|
||||||
"DiagonalModule_transposed",
|
|
||||||
"DiagonalModule_with_dims",
|
|
||||||
"DiagonalModule_with_dims_and_offset",
|
|
||||||
"DiagonalModule_with_negative_dims",
|
|
||||||
"DiagonalModule_with_offset",
|
|
||||||
"DiagonalWithStaticShapeModule_basic",
|
|
||||||
"DivFloatModule_basic",
|
"DivFloatModule_basic",
|
||||||
"DivIntModule_basic",
|
"DivIntModule_basic",
|
||||||
"DropoutTrainModule_basic",
|
"DropoutTrainModule_basic",
|
||||||
|
@ -3478,20 +3535,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"EqIntModule_basic",
|
"EqIntModule_basic",
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
"ExponentialModule_basic",
|
"ExponentialModule_basic",
|
||||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
|
||||||
"FakeQuantizePerTensorAffineModule_basic",
|
|
||||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
|
||||||
"Fill_TensorFloat32WithFloat32_basic",
|
|
||||||
"Fill_TensorFloat32WithFloat64_basic",
|
|
||||||
"Fill_TensorFloat32WithInt64_basic",
|
|
||||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
|
||||||
"Fill_TensorFloat64WithFloat32_basic",
|
|
||||||
"Fill_TensorFloat64WithFloat64_basic",
|
|
||||||
"Fill_TensorFloat64WithInt64Static_basic",
|
|
||||||
"Fill_TensorFloat64WithInt64_basic",
|
|
||||||
"FlipModuleStaticShape_basic",
|
|
||||||
"FlipModule_basic",
|
|
||||||
"FlipNegativeIndexModule_basic",
|
|
||||||
"FloatImplicitModule_basic",
|
"FloatImplicitModule_basic",
|
||||||
"FullLikeModuleInt2D_basic",
|
"FullLikeModuleInt2D_basic",
|
||||||
"FullLikeModuleInt3D_basic",
|
"FullLikeModuleInt3D_basic",
|
||||||
|
@ -3547,15 +3590,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||||
"IndexPutImplIndexWithNoneModule_basic",
|
"IndexPutImplIndexWithNoneModule_basic",
|
||||||
"IndexSelectDynamicIndexSizeModule_basic",
|
|
||||||
"IndexSelectDynamicInputSizeModule_basic",
|
|
||||||
"IndexSelectDynamicModulebasic",
|
|
||||||
"IndexSelectNegativeDimModule_basic",
|
|
||||||
"IndexSelectRank0IdxModule_basic",
|
"IndexSelectRank0IdxModule_basic",
|
||||||
"IndexSelectSingleIdxModule_basic",
|
|
||||||
"IndexSelectTwoIdxModule_basic",
|
|
||||||
"IndexSelectWholeDimensionModule_basic",
|
|
||||||
"IndexSelectWholeTensorModule_basic",
|
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
"InterpolateDynamicModule_sizes_bilinear",
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
"InterpolateDynamicModule_sizes_nearest",
|
"InterpolateDynamicModule_sizes_nearest",
|
||||||
|
@ -3753,6 +3788,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"SignAndLogarithmOfDeterminantModule_F32",
|
"SignAndLogarithmOfDeterminantModule_F32",
|
||||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||||
|
"SliceStaticComplexInputModule_basic",
|
||||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||||
"SliceCopyNegative_Module_basic",
|
"SliceCopyNegative_Module_basic",
|
||||||
"SliceCopyNonZeroDim_Module_basic",
|
"SliceCopyNonZeroDim_Module_basic",
|
||||||
|
@ -3808,11 +3844,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ToCopyWithDTypeModule_basic",
|
"ToCopyWithDTypeModule_basic",
|
||||||
"TorchPrimLoopForLikeModule_basic",
|
"TorchPrimLoopForLikeModule_basic",
|
||||||
"TorchPrimLoopWhileLikeModule_basic",
|
"TorchPrimLoopWhileLikeModule_basic",
|
||||||
"TraceModule_basic",
|
|
||||||
"TraceModule_empty",
|
"TraceModule_empty",
|
||||||
"TraceModule_nonsquare",
|
|
||||||
"TraceSignedIntModule_basic",
|
|
||||||
"TraceUnsignedIntModule_basic",
|
|
||||||
"TraceUnsignedIntModule_empty",
|
"TraceUnsignedIntModule_empty",
|
||||||
"TypeConversionI1ToF64Module_basic",
|
"TypeConversionI1ToF64Module_basic",
|
||||||
"TypeConversionI1ToI32Module_basic",
|
"TypeConversionI1ToI32Module_basic",
|
||||||
|
@ -3833,9 +3865,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"VarMeanUnbiasedModule_basic",
|
"VarMeanUnbiasedModule_basic",
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"ZeroFloat32Module_basic",
|
"VisionTransformerModule_basic",
|
||||||
"ZeroInt32Module_basic",
|
|
||||||
"ZeroInt64Module_basic",
|
|
||||||
"ZerosLikeModule_falsePinMemory",
|
"ZerosLikeModule_falsePinMemory",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3848,6 +3878,15 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"ArangeZeroElementOutputModule_basic",
|
||||||
|
"LinspaceEmptyModule_basic",
|
||||||
|
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||||
|
"TrilIndicesAllZerosModule_basic",
|
||||||
|
"TriuIndicesAllZerosModule_basic",
|
||||||
|
"ElementwiseCreateComplexModule_basic",
|
||||||
|
"ReduceAllDimFloatModule_basic",
|
||||||
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"HstackBasicComplexModule_basic",
|
"HstackBasicComplexModule_basic",
|
||||||
"HstackBasicFloatModule_basic",
|
"HstackBasicFloatModule_basic",
|
||||||
|
@ -3877,7 +3916,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"Conv_Transpose2dStaticModule_basic",
|
"Conv_Transpose2dStaticModule_basic",
|
||||||
"Conv_Transpose3dModule_basic",
|
"Conv_Transpose3dModule_basic",
|
||||||
"Conv_Transpose3dStaticModule_basic",
|
"Conv_Transpose3dStaticModule_basic",
|
||||||
"EinsumStaticDiagonalDimensionModule_basic",
|
|
||||||
"EinsumStaticModule_basic",
|
"EinsumStaticModule_basic",
|
||||||
"ElementwiseFmaxModule_basic",
|
"ElementwiseFmaxModule_basic",
|
||||||
"ElementwiseFminModule_basic",
|
"ElementwiseFminModule_basic",
|
||||||
|
@ -4010,8 +4048,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"AtenPolarDoubleModule_basic",
|
"AtenPolarDoubleModule_basic",
|
||||||
"AtenRealView128Module_basic",
|
"AtenRealView128Module_basic",
|
||||||
"AtenRealView64Module_basic",
|
"AtenRealView64Module_basic",
|
||||||
"AtenRoundFloatHalfToEvenModule_basic",
|
|
||||||
"AtenRoundFloatModule_basic",
|
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenTopKModule_basic",
|
"AtenTopKModule_basic",
|
||||||
"AtenTopKSmallestModule_basic",
|
"AtenTopKSmallestModule_basic",
|
||||||
|
@ -4055,8 +4091,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"BucketizeTensorFloatModule_basic",
|
"BucketizeTensorFloatModule_basic",
|
||||||
"BucketizeTensorModule_basic",
|
"BucketizeTensorModule_basic",
|
||||||
"BucketizeTensorOutInt32RightModule_basic",
|
"BucketizeTensorOutInt32RightModule_basic",
|
||||||
"BucketizeTensorStaticFloatModule_basic",
|
|
||||||
"BucketizeTensorStaticModule_basic",
|
|
||||||
"CeilFloatModule_basic",
|
"CeilFloatModule_basic",
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
|
@ -4075,6 +4109,7 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ContainsIntList_False",
|
"ContainsIntList_False",
|
||||||
"ContainsIntList_True",
|
"ContainsIntList_True",
|
||||||
"Conv1dModule_basic",
|
"Conv1dModule_basic",
|
||||||
|
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||||
"Conv2dBiasNoPaddingModule_basic",
|
"Conv2dBiasNoPaddingModule_basic",
|
||||||
"Conv2dModule_basic",
|
"Conv2dModule_basic",
|
||||||
"Conv2dNoPaddingModule_basic",
|
"Conv2dNoPaddingModule_basic",
|
||||||
|
@ -4115,6 +4150,10 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
"CumsumStaticModule_basic",
|
"CumsumStaticModule_basic",
|
||||||
"CumsumStaticNegativeDimModule_basic",
|
"CumsumStaticNegativeDimModule_basic",
|
||||||
|
"CumprodModule_basic",
|
||||||
|
"CumprodInputDtypeInt32Module_basic",
|
||||||
|
"CumprodStaticModule_basic",
|
||||||
|
"CumprodStaticNegativeDimModule_basic",
|
||||||
"DeformConv2D_basic",
|
"DeformConv2D_basic",
|
||||||
"DeterminantModule_F32",
|
"DeterminantModule_F32",
|
||||||
"DeterminantBatchedModule_F32",
|
"DeterminantBatchedModule_F32",
|
||||||
|
@ -4265,7 +4304,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseWhereSelfModule_basic",
|
"ElementwiseWhereSelfModule_basic",
|
||||||
"EmbeddingModule1DIndices_basic",
|
"EmbeddingModule1DIndices_basic",
|
||||||
"EmbeddingModuleF16_basic",
|
"EmbeddingModuleF16_basic",
|
||||||
"EmbeddingModuleI32Static_basic",
|
|
||||||
"EmbeddingModuleI32_basic",
|
"EmbeddingModuleI32_basic",
|
||||||
"EmbeddingModuleI64_basic",
|
"EmbeddingModuleI64_basic",
|
||||||
"EmptyLikeMemoryFormatModule_basic",
|
"EmptyLikeMemoryFormatModule_basic",
|
||||||
|
@ -4359,12 +4397,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"IndexSelectDynamicIndexSizeModule_basic",
|
"IndexSelectDynamicIndexSizeModule_basic",
|
||||||
"IndexSelectDynamicInputSizeModule_basic",
|
"IndexSelectDynamicInputSizeModule_basic",
|
||||||
"IndexSelectDynamicModulebasic",
|
"IndexSelectDynamicModulebasic",
|
||||||
"IndexSelectNegativeDimModule_basic",
|
|
||||||
"IndexSelectRank0IdxModule_basic",
|
|
||||||
"IndexSelectSingleIdxModule_basic",
|
|
||||||
"IndexSelectTwoIdxModule_basic",
|
|
||||||
"IndexSelectWholeDimensionModule_basic",
|
|
||||||
"IndexSelectWholeTensorModule_basic",
|
|
||||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||||
"IndexTensorHackedTwinModule3dInput_basic",
|
"IndexTensorHackedTwinModule3dInput_basic",
|
||||||
|
@ -4382,10 +4414,8 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"IndexTensorMultiInputOneDim_basic",
|
"IndexTensorMultiInputOneDim_basic",
|
||||||
"IndexTensorMultiInputThreeIndexers_basic",
|
"IndexTensorMultiInputThreeIndexers_basic",
|
||||||
"IndexTensorMultiInput_basic",
|
"IndexTensorMultiInput_basic",
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
|
||||||
"IndexTensorSelectDimModule_basic",
|
"IndexTensorSelectDimModule_basic",
|
||||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||||
"IndexTensorStaticModule_basic",
|
|
||||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||||
"InterpolateDynamicModule_sizes_bilinear",
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
"InterpolateDynamicModule_sizes_nearest",
|
"InterpolateDynamicModule_sizes_nearest",
|
||||||
|
@ -4684,7 +4714,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ScatterValueFloatModule_basic",
|
"ScatterValueFloatModule_basic",
|
||||||
"ScatterValueIntModule_basic",
|
"ScatterValueIntModule_basic",
|
||||||
"SelectIntModule_basic",
|
"SelectIntModule_basic",
|
||||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
|
||||||
"SelectScattertModule_basic",
|
"SelectScattertModule_basic",
|
||||||
"SelectScattertStaticModule_basic",
|
"SelectScattertStaticModule_basic",
|
||||||
"SignAndLogarithmOfDeterminantModule_F32",
|
"SignAndLogarithmOfDeterminantModule_F32",
|
||||||
|
@ -4696,6 +4725,7 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"SliceCopy_Module_basic",
|
"SliceCopy_Module_basic",
|
||||||
"SliceEndSleStartModule_basic",
|
"SliceEndSleStartModule_basic",
|
||||||
"SliceModule_basic",
|
"SliceModule_basic",
|
||||||
|
"SliceStaticComplexInputModule_basic",
|
||||||
"SliceNegIdxModule_basic",
|
"SliceNegIdxModule_basic",
|
||||||
"SliceOutOfLowerBoundEndIndexModule_basic",
|
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||||
|
|
|
@ -1445,6 +1445,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b
|
||||||
def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -2001,6 +2004,14 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
|
||||||
def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
|
def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
|
||||||
return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
|
return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
|
||||||
|
|
||||||
|
def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]:
|
||||||
|
scalar_shape: List[int] = []
|
||||||
|
if reduction == 0:
|
||||||
|
result_shape = upstream_shape_functions._copy(self)
|
||||||
|
else:
|
||||||
|
result_shape = scalar_shape
|
||||||
|
return result_shape
|
||||||
|
|
||||||
@check_shape_function([
|
@check_shape_function([
|
||||||
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||||
])
|
])
|
||||||
|
@ -2937,6 +2948,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
|
||||||
return torch.int64
|
return torch.int64
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32))
|
||||||
|
def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
|
||||||
|
if dtype is not None:
|
||||||
|
return dtype
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
if is_integer_dtype(self_dtype):
|
||||||
|
return torch.int64
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
@ -4954,6 +4977,10 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
|
||||||
return dtype
|
return dtype
|
||||||
return aten〇std〡dtype(self_rank_dtype)
|
return aten〇std〡dtype(self_rank_dtype)
|
||||||
|
|
||||||
|
def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(
|
_check_tensors_with_the_same_dtype(
|
||||||
tensor_shapes=[(3,3)],
|
tensor_shapes=[(3,3)],
|
||||||
|
@ -5543,7 +5570,45 @@ def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int,
|
||||||
return torch.qint8
|
return torch.qint8
|
||||||
return torch.qint32
|
return torch.qint32
|
||||||
|
|
||||||
|
@check_shape_function([
|
||||||
|
Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero.
|
||||||
|
Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0.
|
||||||
|
Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case.
|
||||||
|
Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case.
|
||||||
|
Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension.
|
||||||
|
Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension.
|
||||||
|
])
|
||||||
|
def aten〇unfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]:
|
||||||
|
ndim = len(self)
|
||||||
|
|
||||||
|
# Rank zero tensor
|
||||||
|
if ndim == 0:
|
||||||
|
assert dimension == 0, f"dimension out of range of {ndim}"
|
||||||
|
assert size <= 1, "size must be less than or equal to 1"
|
||||||
|
return [size]
|
||||||
|
|
||||||
|
dim = dimension
|
||||||
|
if dim < 0:
|
||||||
|
dim += ndim
|
||||||
|
|
||||||
|
assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}"
|
||||||
|
|
||||||
|
size_dim = self[dim]
|
||||||
|
assert size <= size_dim, f"size must be less than or equal to {size_dim}"
|
||||||
|
|
||||||
|
num_blocks = (size_dim - size) // step + 1
|
||||||
|
|
||||||
|
out = upstream_shape_functions._copy(self)
|
||||||
|
out[dim] = num_blocks
|
||||||
|
out.append(size)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1)
|
||||||
|
)
|
||||||
|
def aten〇unfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::rad2deg : (Tensor) -> (Tensor)")
|
emit("aten::rad2deg : (Tensor) -> (Tensor)")
|
||||||
|
emit("aten::complex : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::real : (Tensor) -> (Tensor)")
|
emit("aten::real : (Tensor) -> (Tensor)")
|
||||||
emit("aten::imag : (Tensor) -> (Tensor)")
|
emit("aten::imag : (Tensor) -> (Tensor)")
|
||||||
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
||||||
|
@ -617,6 +618,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
|
emit(
|
||||||
|
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||||
|
)
|
||||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
|
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
|
@ -740,6 +744,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
emit(
|
||||||
|
"aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
||||||
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
@ -986,6 +993,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
|
emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
|
||||||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||||
|
emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
|
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
|
||||||
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
|
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
|
||||||
|
|
|
@ -42,7 +42,7 @@ def import_onnx(contents):
|
||||||
# Import the ONNX model proto from the file contents:
|
# Import the ONNX model proto from the file contents:
|
||||||
raw_model = onnx.load_from_string(contents)
|
raw_model = onnx.load_from_string(contents)
|
||||||
# since it does not affect current e2e tests, data_prop is left false here
|
# since it does not affect current e2e tests, data_prop is left false here
|
||||||
model_proto = onnx.shape_inference.infer_shapes(raw_model)
|
model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
|
||||||
|
|
||||||
# Import the ONNX module into an MLIR module:
|
# Import the ONNX module into an MLIR module:
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
|
@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class CumprodModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, val):
|
||||||
|
ones = torch.ones([1], dtype=torch.int32)
|
||||||
|
return torch.ops.aten.cumprod(val, ones.item())
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: CumprodModule())
|
||||||
|
def CumprodModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 7, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class CumprodStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 7, 4], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, val):
|
||||||
|
return torch.ops.aten.cumprod(val, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: CumprodStaticModule())
|
||||||
|
def CumprodStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 7, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class CumprodStaticNegativeDimModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 7, 4], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, val):
|
||||||
|
return torch.ops.aten.cumprod(val, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule())
|
||||||
|
def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 7, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class CumprodInputDtypeInt32Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 7, 4], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, val):
|
||||||
|
return torch.ops.aten.cumprod(val, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module())
|
||||||
|
def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(2, 7, 4).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenToDeviceModule(torch.nn.Module):
|
class AtenToDeviceModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
|
||||||
module.forward(inputVec, weight, bias)
|
module.forward(inputVec, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 4, 6], torch.float32, True),
|
||||||
|
([4, 1, 3], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, inputVec, weight):
|
||||||
|
return torch.ops.aten.conv1d(
|
||||||
|
inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
|
||||||
|
)
|
||||||
|
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
||||||
|
inputVec = tu.rand(2, 4, 6)
|
||||||
|
weight = torch.randn(4, 1, 3)
|
||||||
|
module.forward(inputVec, weight)
|
||||||
|
|
||||||
|
|
||||||
class Conv2dModule(torch.nn.Module):
|
class Conv2dModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseCreateComplexModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.complex(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule())
|
||||||
|
def ElementwiseCreateComplexModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(4, high=10).type(torch.float32),
|
||||||
|
tu.randint(4, high=10).type(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseMulTensorComplexModule(torch.nn.Module):
|
class ElementwiseMulTensorComplexModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -1783,6 +1783,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 512, 10))
|
module.forward(tu.rand(1, 512, 10))
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([None, ([1, 512, 7], torch.float32, True)])
|
||||||
|
def forward(self, x):
|
||||||
|
return self.amp1d(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic())
|
||||||
|
def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 512, 7))
|
||||||
|
|
||||||
|
|
||||||
# AdaptiveMaxPool2d
|
# AdaptiveMaxPool2d
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5))
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceAllDimFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.all(a, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
|
||||||
|
def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@ -2294,6 +2314,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(8, 2), tu.randint(8, high=2))
|
module.forward(tu.rand(8, 2), tu.randint(8, high=2))
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([8, 2], torch.float32, True),
|
||||||
|
([8, 2], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, input, target):
|
||||||
|
return torch.ops.aten.binary_cross_entropy_with_logits(
|
||||||
|
input, target, reduction=0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule())
|
||||||
|
def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(8, 2), tu.rand(8, 2))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ViewDtypeStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([12, 1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
res = a.view(torch.int8)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
|
||||||
|
def ViewDtypeStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(12, 1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ReshapeAliasCollapseModule(torch.nn.Module):
|
class ReshapeAliasCollapseModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1648,3 +1672,103 @@ class Rot90NegativeEvenRotationsModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
|
@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
|
||||||
def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
|
def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(6, 5, 1, 7, 3))
|
module.forward(tu.rand(6, 5, 1, 7, 3))
|
||||||
|
|
||||||
|
|
||||||
|
class Unfold_Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([6, 4], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x.unfold(0, 2, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Unfold_Module())
|
||||||
|
def Unfold_Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class Unfold_Module_Negative_Dim(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([6, 4, 4, 4], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x.unfold(-1, 2, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Unfold_Module_Negative_Dim())
|
||||||
|
def Unfold_Module_Rank_4(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 4, 4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class Unfold_Module_Rank_Zero(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x.unfold(0, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
|
||||||
|
def Unfold_Module_Rank_Zero_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand())
|
||||||
|
|
||||||
|
|
||||||
|
class Unfold_Module_Rank_Zero_Size_Zero(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x.unfold(0, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
|
||||||
|
def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand())
|
||||||
|
|
||||||
|
|
||||||
|
class Unfold_Module_Dynamic(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x.unfold(1, 2, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
|
||||||
|
def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 4, 4, 4))
|
||||||
|
|
|
@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class SliceStaticComplexInputModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([6, 4, 7], torch.complex64, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x[0:5:1, 1:3:1, 2:4:1]
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
|
||||||
|
def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 4, 7).to(torch.complex64))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
4
setup.py
4
setup.py
|
@ -223,13 +223,13 @@ INSTALL_REQUIRES = [
|
||||||
EXT_MODULES = [
|
EXT_MODULES = [
|
||||||
CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
|
CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
|
||||||
]
|
]
|
||||||
NAME = "torch-mlir-core"
|
NAME = "torch-mlir"
|
||||||
|
|
||||||
# If building PyTorch extensions, customize.
|
# If building PyTorch extensions, customize.
|
||||||
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
|
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
NAME = "torch-mlir"
|
NAME = "torch-mlir-ext"
|
||||||
INSTALL_REQUIRES.extend(
|
INSTALL_REQUIRES.extend(
|
||||||
[
|
[
|
||||||
f"torch=={torch.__version__}".split("+", 1)[0],
|
f"torch=={torch.__version__}".split("+", 1)[0],
|
||||||
|
|
|
@ -16,10 +16,71 @@
|
||||||
// CHECK-DAG: torch.prim.Loop.condition
|
// CHECK-DAG: torch.prim.Loop.condition
|
||||||
// CHECK-DAG: }
|
// CHECK-DAG: }
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
module {
|
|
||||||
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||||
%none = torch.constant.none
|
%none = torch.constant.none
|
||||||
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
|
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
|
||||||
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
|
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias(
|
||||||
|
// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>,
|
||||||
|
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>,
|
||||||
|
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>,
|
||||||
|
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>)
|
||||||
|
// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) {
|
||||||
|
// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>):
|
||||||
|
// CHECK-DAG: torch.aten.select.int
|
||||||
|
// CHECK-DAG: torch.aten.linear
|
||||||
|
// CHECK-DAG: torch.aten.sigmoid
|
||||||
|
// CHECK-DAG: torch.aten.tanh
|
||||||
|
// CHECK-DAG: torch.prim.Loop.condition
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: torch.aten.flip
|
||||||
|
// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) {
|
||||||
|
// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>):
|
||||||
|
// CHECK-DAG: torch.aten.select.int
|
||||||
|
// CHECK-DAG: torch.aten.linear
|
||||||
|
// CHECK-DAG: torch.aten.sigmoid
|
||||||
|
// CHECK-DAG: torch.aten.tanh
|
||||||
|
// CHECK-DAG: torch.prim.Loop.condition
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: torch.aten.flip
|
||||||
|
// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
|
||||||
|
return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs(
|
||||||
|
// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>,
|
||||||
|
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>,
|
||||||
|
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>)
|
||||||
|
// CHECK: torch.aten.transpose.int
|
||||||
|
// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
|
||||||
|
// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>):
|
||||||
|
// CHECK-DAG: torch.aten.select.int
|
||||||
|
// CHECK-DAG: torch.aten.linear
|
||||||
|
// CHECK-DAG: torch.aten.sigmoid
|
||||||
|
// CHECK-DAG: torch.aten.tanh
|
||||||
|
// CHECK-DAG: torch.prim.Loop.condition
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK-DAG: torch.aten.transpose.int
|
||||||
|
// CHECK-DAG: torch.aten.transpose.int
|
||||||
|
// CHECK-DAG: torch.aten.transpose.int
|
||||||
|
// CHECK-DAG: torch.aten.transpose.int
|
||||||
|
// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>)
|
||||||
|
return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
|
||||||
}
|
}
|
||||||
|
|
|
@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
|
||||||
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
|
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||||
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
|
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
|
||||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
|
||||||
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
|
||||||
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
|
|
||||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||||
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||||
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
|
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
|
||||||
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
// CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
|
||||||
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
|
||||||
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
||||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
|
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
|
||||||
return %0 : !torch.vtensor<[3,4],f32>
|
return %0 : !torch.vtensor<[3,4],f32>
|
||||||
|
@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
|
||||||
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
|
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
|
||||||
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
|
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
|
||||||
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
|
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
|
||||||
|
// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
|
||||||
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
|
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
|
||||||
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
|
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
|
||||||
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
// CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
|
||||||
|
// CHECK-NEXT: torch.runtime.assert %[[GE]]
|
||||||
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
|
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
|
||||||
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
|
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
|
||||||
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
|
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
|
||||||
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
|
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
|
||||||
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
|
|
||||||
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
|
|
||||||
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
|
|
||||||
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
|
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
|
||||||
// CHECK: return %[[EXPAND]]
|
// CHECK: return %[[EXPAND]]
|
||||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
|
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
|
||||||
|
|
|
@ -262,14 +262,15 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
|
||||||
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
|
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
|
||||||
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
|
||||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
|
||||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
|
||||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
|
||||||
// CHECK: %[[STR:.*]] = torch.constant.str "add"
|
// CHECK: %[[STR:.*]] = torch.constant.str "sum"
|
||||||
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
|
||||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||||
return %0 : !torch.vtensor<[1,5],f32>
|
return %0 : !torch.vtensor<[1,5],f32>
|
||||||
}
|
}
|
||||||
|
@ -295,14 +296,15 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
|
||||||
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
|
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
|
||||||
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
|
||||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
|
||||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
|
||||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
|
||||||
// CHECK: %[[STR:.*]] = torch.constant.str "multiply"
|
// CHECK: %[[STR:.*]] = torch.constant.str "prod"
|
||||||
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
|
||||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||||
return %0 : !torch.vtensor<[1,5],f32>
|
return %0 : !torch.vtensor<[1,5],f32>
|
||||||
}
|
}
|
||||||
|
@ -2833,6 +2835,15 @@ func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>)
|
||||||
return %0 : !torch.vtensor<[1],si64>
|
return %0 : !torch.vtensor<[1],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_shape_scalar
|
||||||
|
func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
|
||||||
|
// CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64>
|
||||||
|
// CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
|
||||||
|
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64>
|
||||||
|
return %0: !torch.vtensor<[?],si64>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
|
||||||
|
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
|
||||||
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
|
||||||
|
// CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
|
||||||
|
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %1 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
|
@ -696,8 +696,8 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
|
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||||
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
|
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32>
|
||||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
|
||||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
@ -890,15 +890,15 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
|
||||||
|
|
||||||
// CHECK-LABEL: @torch.aten.max.dim$basic(
|
// CHECK-LABEL: @torch.aten.max.dim$basic(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
|
||||||
// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
// CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||||
// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
// CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
|
// CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
// CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
// CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||||
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
// CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
// CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||||
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
// CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||||
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
||||||
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||||
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||||
|
@ -1378,16 +1378,16 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
|
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
// CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
// CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
// CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||||
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
// CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||||
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
// CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
// CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
// CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||||
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
// CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||||
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
|
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||||
|
@ -1859,3 +1859,142 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
|
||||||
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||||
return %0: !torch.vtensor<[?,?],si32>
|
return %0: !torch.vtensor<[?,?],si32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.diagonal$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int -2
|
||||||
|
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array<i64: 0, 0, 2, 0>} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32>
|
||||||
|
// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> {
|
||||||
|
%dim1 = torch.constant.int 1
|
||||||
|
%dim2 = torch.constant.int 0
|
||||||
|
%offset = torch.constant.int -2
|
||||||
|
%0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
|
||||||
|
return %0 : !torch.vtensor<[5,6,2],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.index_select(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
|
||||||
|
// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,5,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.fill.Scalar(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
%0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.flip(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %1 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.round(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
|
@ -4,17 +4,17 @@
|
||||||
// CHECK-LABEL: func.func @scan_1d_inclusive(
|
// CHECK-LABEL: func.func @scan_1d_inclusive(
|
||||||
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
|
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
|
||||||
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||||
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||||
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||||
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||||
|
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||||
|
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||||
// CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
// CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
||||||
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
||||||
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
|
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
|
||||||
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
||||||
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
|
||||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
|
||||||
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
||||||
func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||||
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
|
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
|
||||||
|
@ -32,8 +32,10 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
||||||
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||||
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||||
// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
|
// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
|
||||||
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||||
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||||
|
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||||
|
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||||
// CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
|
// CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
|
||||||
// CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
// CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
||||||
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
||||||
|
@ -41,8 +43,6 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
||||||
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
||||||
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
|
||||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
|
||||||
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
||||||
func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||||
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
|
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
|
||||||
|
@ -62,14 +62,14 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
||||||
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||||
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||||
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||||
|
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||||
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
||||||
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
|
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
|
||||||
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
||||||
func.func @scatter_update_scalar_1D(
|
func.func @scatter_update_scalar_1D(
|
||||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||||
|
@ -90,7 +90,8 @@ func.func @scatter_update_scalar_1D(
|
||||||
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||||
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||||
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||||
|
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||||
|
@ -99,7 +100,6 @@ func.func @scatter_update_scalar_1D(
|
||||||
// CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
|
// CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
|
||||||
// CHECK: tm_tensor.yield %[[ADD]] : i32
|
// CHECK: tm_tensor.yield %[[ADD]] : i32
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
|
||||||
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
||||||
func.func @scatter_add_scalar_1D(
|
func.func @scatter_add_scalar_1D(
|
||||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||||
|
|
|
@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
|
||||||
|
|
||||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||||
// expected-error @+1 {{op requires non-empty shapeSymbols}}
|
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
|
||||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||||
return %arg0 : !torch.vtensor<[?],f32>
|
return %arg0 : !torch.vtensor<[?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Verifier should not fail here since the op does not require shapeSymbols.
|
||||||
|
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
|
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
|
||||||
|
return %arg0 : !torch.vtensor<[?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
||||||
|
|
|
@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
|
||||||
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
|
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
|
||||||
return %slice : !torch.vtensor<[2],si32>
|
return %slice : !torch.vtensor<[2],si32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @view_as_flatten_static
|
||||||
|
func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
|
||||||
|
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||||
|
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
|
||||||
|
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
|
||||||
|
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
|
||||||
|
%int1024 = torch.constant.int 1024
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
|
||||||
|
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
|
||||||
|
%2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
|
||||||
|
return %3 : !torch.vtensor<[?,?,1024],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @view_as_unflatten_static
|
||||||
|
func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
|
||||||
|
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||||
|
// CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
|
||||||
|
// CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
|
||||||
|
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
|
||||||
|
%int16 = torch.constant.int 16
|
||||||
|
%int64 = torch.constant.int 64
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
|
||||||
|
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
|
||||||
|
%2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
|
||||||
|
return %3 : !torch.vtensor<[?,?,16,64],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @view_as_flatten_dynamic
|
||||||
|
func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||||
|
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||||
|
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
|
||||||
|
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
|
||||||
|
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
%2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
|
||||||
|
return %3 : !torch.vtensor<[?,?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @unsqueeze_squeeze_combo
|
||||||
|
func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
|
||||||
|
// CHECK: %int0 = torch.constant.int 0
|
||||||
|
// CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: return %0 : !torch.int
|
||||||
|
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
%10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
%11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
|
||||||
|
%12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
|
||||||
|
%13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
return %14 : !torch.int
|
||||||
|
}
|
||||||
|
|
|
@ -34,5 +34,5 @@ def test_enable_ir_printing():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
|
# CHECK: // -----// IR Dump After Inliner (inline)
|
||||||
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {
|
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {
|
||||||
|
|
Loading…
Reference in New Issue