mirror of https://github.com/llvm/torch-mlir
Merge branch 'main' into lower_torch_aten_gcd_to_linalg_and_scf
commit
ee7f6ee9fd
|
@ -50,7 +50,7 @@ TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}"
|
|||
# Location to store Release wheels
|
||||
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
|
||||
# What "packages to build"
|
||||
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}"
|
||||
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-ext}"
|
||||
# Use pre-built Pytorch
|
||||
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
|
||||
# Skip running tests if you want quick iteration
|
||||
|
@ -83,12 +83,12 @@ function run_on_host() {
|
|||
fi
|
||||
mkdir -p "${TM_OUTPUT_DIR}"
|
||||
case "$package" in
|
||||
torch-mlir)
|
||||
torch-mlir-ext)
|
||||
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
||||
export USERID=0
|
||||
export GROUPID=0
|
||||
;;
|
||||
torch-mlir-core)
|
||||
torch-mlir)
|
||||
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
||||
export USERID=0
|
||||
export GROUPID=0
|
||||
|
@ -158,22 +158,22 @@ function run_in_docker() {
|
|||
export PATH=$python_dir/bin:$orig_path
|
||||
echo ":::: Python version $(python3 --version)"
|
||||
case "$package" in
|
||||
torch-mlir)
|
||||
clean_wheels torch_mlir "$python_version"
|
||||
build_torch_mlir "$TM_TORCH_VERSION"
|
||||
torch-mlir-ext)
|
||||
clean_wheels torch_mlir_ext "$python_version"
|
||||
build_torch_mlir_ext "$TM_TORCH_VERSION"
|
||||
|
||||
# Disable audit wheel until we can fix ODR torch issues. See
|
||||
# https://github.com/llvm/torch-mlir/issues/1709
|
||||
#
|
||||
#run_audit_wheel torch_mlir "$python_version"
|
||||
#run_audit_wheel torch_mlir_ext "$python_version"
|
||||
|
||||
clean_build torch_mlir "$python_version"
|
||||
clean_build torch_mlir_ext "$python_version"
|
||||
;;
|
||||
torch-mlir-core)
|
||||
clean_wheels torch_mlir_core "$python_version"
|
||||
build_torch_mlir_core
|
||||
run_audit_wheel torch_mlir_core "$python_version"
|
||||
clean_build torch_mlir_core "$python_version"
|
||||
torch-mlir)
|
||||
clean_wheels torch_mlir "$python_version"
|
||||
build_torch_mlir
|
||||
run_audit_wheel torch_mlir "$python_version"
|
||||
clean_build torch_mlir "$python_version"
|
||||
;;
|
||||
out-of-tree)
|
||||
setup_venv "$python_version" "$TM_TORCH_VERSION"
|
||||
|
@ -431,7 +431,7 @@ function clean_build() {
|
|||
rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch
|
||||
}
|
||||
|
||||
function build_torch_mlir() {
|
||||
function build_torch_mlir_ext() {
|
||||
# Disable LTC build for releases
|
||||
export TORCH_MLIR_ENABLE_LTC=0
|
||||
local torch_version="$1"
|
||||
|
@ -470,7 +470,9 @@ function run_audit_wheel() {
|
|||
rm "$generic_wheel"
|
||||
}
|
||||
|
||||
function build_torch_mlir_core() {
|
||||
function build_torch_mlir() {
|
||||
# Disable LTC build for releases
|
||||
export TORCH_MLIR_ENABLE_LTC=0
|
||||
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
||||
CMAKE_GENERATOR=Ninja \
|
||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||
|
|
|
@ -56,16 +56,16 @@ function run() {
|
|||
export PATH=$python_dir/bin:$orig_path
|
||||
echo ":::: Python version $(python3 --version)"
|
||||
case "$package" in
|
||||
torch-mlir-ext)
|
||||
clean_wheels torch_mlir_ext "$python_version"
|
||||
build_torch_mlir_ext torch_mlir_ext "$python_version"
|
||||
run_audit_wheel torch_mlir_ext "$python_version"
|
||||
;;
|
||||
torch-mlir)
|
||||
clean_wheels torch_mlir "$python_version"
|
||||
build_torch_mlir torch_mlir "$python_version"
|
||||
run_audit_wheel torch_mlir "$python_version"
|
||||
;;
|
||||
torch-mlir-core)
|
||||
clean_wheels torch_mlir_core "$python_version"
|
||||
build_torch_mlir_core torch_mlir_core "$python_version"
|
||||
run_audit_wheel torch_mlir_core "$python_version"
|
||||
;;
|
||||
*)
|
||||
echo "Unrecognized package '$package'"
|
||||
exit 1
|
||||
|
@ -75,7 +75,7 @@ function run() {
|
|||
done
|
||||
}
|
||||
|
||||
function build_torch_mlir() {
|
||||
function build_torch_mlir_ext() {
|
||||
local wheel_basename="$1"
|
||||
local python_version="$2"
|
||||
rm -rf "$output_dir"/build_venv
|
||||
|
@ -93,7 +93,7 @@ function build_torch_mlir() {
|
|||
rm -rf "$output_dir"/build_venv
|
||||
}
|
||||
|
||||
function build_torch_mlir_core() {
|
||||
function build_torch_mlir() {
|
||||
local wheel_basename="$1"
|
||||
local python_version="$2"
|
||||
rm -rf "$output_dir"/build_venv
|
||||
|
|
|
@ -14,7 +14,7 @@ While this is running, you can already setup the Python venv and dependencies in
|
|||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
```shell
|
||||
python -m venv mlir_venv
|
||||
python3 -m venv mlir_venv
|
||||
source mlir_venv/bin/activate
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit d418a03e01e6a31b51b0c9dd42ba46da6c47f89d
|
||||
Subproject commit e813750354bbc08551cf23ff559a54b4a9ea1f29
|
|
@ -1 +1 @@
|
|||
Subproject commit c28d55e91b4a5daaff18a33ce7e9bbd0f171256a
|
||||
Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739
|
|
@ -34,6 +34,7 @@ struct OpBinder {
|
|||
Location getLoc() { return op->getLoc(); }
|
||||
|
||||
int getNumOperands() { return op->getNumOperands(); }
|
||||
int getNumResults() { return op->getNumResults(); }
|
||||
|
||||
// Operand matches of different arities.
|
||||
ParseResult tensorOperand(Value &value0) {
|
||||
|
@ -338,6 +339,31 @@ struct OpBinder {
|
|||
return failure();
|
||||
}
|
||||
|
||||
ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
|
||||
StringRef nameSuffix,
|
||||
ArrayRef<float> defaults) {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
name.append(nameSuffix);
|
||||
auto attr = op->getAttr(name);
|
||||
if (!attr) {
|
||||
values.append(defaults.begin(), defaults.end());
|
||||
return success();
|
||||
}
|
||||
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||
for (auto element : arrayAttr) {
|
||||
auto floatAttr = dyn_cast<FloatAttr>(element);
|
||||
if (!floatAttr)
|
||||
return failure();
|
||||
FloatType t = cast<FloatType>(floatAttr.getType());
|
||||
if (t.getWidth() != 32)
|
||||
return failure();
|
||||
values.push_back(floatAttr.getValue().convertToFloat());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
|
||||
StringRef nameSuffix) {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
|
|
|
@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
|
|||
Location loc, SmallVector<int64_t> dimensions,
|
||||
Value input, Value &result);
|
||||
|
||||
// Flips an input tensor based on the values of axis list.
|
||||
Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
|
||||
SmallVector<int64_t> axis);
|
||||
|
||||
} // namespace torch_to_linalg
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
|
|
|
@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|||
|
||||
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||
Type elemTy);
|
||||
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||
Type elemTy);
|
||||
|
||||
Value castIntToIndex(OpBuilder &b, Location loc, Value v);
|
||||
|
||||
|
|
|
@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenComplexOp : Torch_Op<"aten.complex", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$real,
|
||||
AnyTorchTensorType:$imag
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenComplexOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
@ -7078,6 +7102,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$kernel_size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$ceil_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result0,
|
||||
AnyTorchOptionalTensorType:$result1
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 2);
|
||||
}
|
||||
void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 2);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -9195,6 +9248,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$target,
|
||||
AnyTorchOptionalTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$pos_weight,
|
||||
Torch_IntType:$reduction
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -13637,6 +13717,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dimension,
|
||||
Torch_IntType:$size,
|
||||
Torch_IntType:$step
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenUnfoldOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
return failure();
|
||||
|
||||
auto shapeSizes = shapeType.getSizes();
|
||||
int64_t dataRank = dataType.getSizes().size();
|
||||
ArrayRef<int64_t> dataShape = dataType.getSizes();
|
||||
int64_t dataRank = dataShape.size();
|
||||
int64_t shapeRank = shapeSizes.size();
|
||||
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
|
||||
return failure();
|
||||
|
@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
// we are using torch implementation Torch::AtenBroadcastToOp which
|
||||
// takes list of int
|
||||
for (int i = 0; i < shapeSizes[0]; i++) {
|
||||
// extract dim from shape
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
loc, selectResultType, shape, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
Value selectDim = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), extract);
|
||||
|
||||
if (i + rankDifference >= 0) {
|
||||
// compute dim to pass to broadcast op. For non-broadcastable dims,
|
||||
// pass -1
|
||||
Value dim;
|
||||
if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
|
||||
// 1. if dataShape[i + rankDiff] > 1, then this cannot be
|
||||
// broadcasted
|
||||
// 2. we will explicitly disallow broadcasting dynamic dims that are
|
||||
// secretly 1.
|
||||
dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
|
||||
// Assert dataShape[i + rankDiff] >= selectDim. If both are
|
||||
// constant, this should fold out.
|
||||
Value iv =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
|
||||
auto sz = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), data, iv);
|
||||
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
|
||||
Value gtSelect =
|
||||
rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
loc, gtSelect,
|
||||
rewriter.getStringAttr(
|
||||
"onnx.Expand input has a dim that is not statically 1; "
|
||||
"expected this dim >= dim provided shape."));
|
||||
} else {
|
||||
// 1. excess selectDims get included in broadcast (shapeSizes[0] >
|
||||
// dataRank)
|
||||
// 2. selectDims which correspond to dataShape == 1 get included in
|
||||
// broadcast
|
||||
dim = selectDim;
|
||||
}
|
||||
|
||||
dimList.push_back(dim);
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
|
|
|
@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
// TODO: Implement max and min cases
|
||||
if (reduction == "mul") {
|
||||
reduction = "multiply";
|
||||
reduction = "prod";
|
||||
} else if (reduction == "max" || reduction == "min") {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "max/min reduction unsupported for scatter elements");
|
||||
} else if (reduction == "add") {
|
||||
reduction = "sum";
|
||||
}
|
||||
|
||||
Value cstStrReduction =
|
||||
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
|
||||
Value cstTrue =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
|
||||
binder.op, resultType, data, constAxis, indices, updates,
|
||||
cstStrReduction);
|
||||
cstStrReduction, cstTrue);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
|
@ -1662,10 +1665,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
auto shapeType = Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), SmallVector<int64_t>{inputRank},
|
||||
resultType.getOptionalDtype());
|
||||
|
||||
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
|
||||
binder.getLoc(), shapeType, operand);
|
||||
|
||||
if (inputRank == 0) {
|
||||
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
|
||||
binder.op, resultType, shape);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (start == 0 && end == -1) {
|
||||
rewriter.replaceOp(binder.op, shape);
|
||||
return success();
|
||||
|
@ -1673,18 +1681,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
Value sv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(start));
|
||||
|
||||
Value ev = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(end));
|
||||
|
||||
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
|
||||
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
|
||||
|
||||
shape = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(), resultType, shape, dim, sv, ev, step);
|
||||
|
||||
rewriter.replaceOp(binder.op, shape);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
|
||||
binder.op, resultType, shape, dim, sv, ev, step);
|
||||
return success();
|
||||
});
|
||||
|
||||
|
@ -4339,6 +4342,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
llvm::SmallVector<int64_t> ngram_counts;
|
||||
llvm::SmallVector<int64_t> ngram_indexes;
|
||||
llvm::SmallVector<int64_t> pool_int64s;
|
||||
llvm::SmallVector<float> weights;
|
||||
std::string mode;
|
||||
int64_t min_gram_length;
|
||||
int64_t max_gram_length;
|
||||
|
@ -4356,9 +4360,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.tensorOperand(input) || binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
if (mode != "TF")
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"TF mode supported only");
|
||||
llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
|
||||
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
|
||||
return failure();
|
||||
|
||||
if (pool_int64s.size() == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "pool_int64s empty, only integers supported");
|
||||
|
@ -4584,9 +4589,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.getLoc(), loopConditionTrue, ValueRange({count}));
|
||||
}
|
||||
count = skipLoop.getResult(0);
|
||||
// insert count "tf" into output
|
||||
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
|
||||
binder.getLoc(), count);
|
||||
if (mode == "IDF" || mode == "TFIDF") {
|
||||
// both IDF and TFIDF modes use weights
|
||||
float weight = weights[ngram_i];
|
||||
Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getF64FloatAttr(weight));
|
||||
|
||||
// TFIDF
|
||||
Value multiplier = countFloat;
|
||||
if (mode == "IDF") {
|
||||
// All the counts larger than 1 would be truncated to 1
|
||||
// and the i-th element in weights would be used to scale
|
||||
// (by multiplication) the count of the i-th n-gram in pool.
|
||||
|
||||
Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
|
||||
binder.getLoc(), count);
|
||||
// compare intCount > 0
|
||||
Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
|
||||
binder.getLoc(), intCount, zero);
|
||||
gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
binder.getLoc(), gtZeroCount);
|
||||
Value gtZeroCountFloat =
|
||||
rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
|
||||
gtZeroCount);
|
||||
multiplier = gtZeroCountFloat;
|
||||
}
|
||||
countFloat = rewriter.create<Torch::AtenMulFloatOp>(
|
||||
binder.getLoc(), multiplier, constWeight);
|
||||
}
|
||||
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
|
|
|
@ -661,8 +661,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
std::string direction;
|
||||
|
||||
ValueTensorType yTy, Y_hType, Y_cType;
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
|
||||
binder.tensorResultTypeAtIndex(Y_hType, 1) ||
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
|
||||
binder.tensorResultTypeAtIndex(Y_hType, 1) &&
|
||||
binder.tensorResultTypeAtIndex(Y_cType, 2)) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"At least one outputs must be present");
|
||||
|
@ -686,51 +686,110 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
|
||||
auto xTy = cast<ValueTensorType>(X.getType());
|
||||
auto wTy = cast<ValueTensorType>(W.getType());
|
||||
Value B;
|
||||
if (binder.tensorOperandAtIndex(B, 3)) {
|
||||
B = b.create<AtenZerosOp>(W.getType(), W);
|
||||
}
|
||||
|
||||
// TODO: add defaults for activation_alpha acticvation_beta attributes
|
||||
|
||||
llvm::SmallVector<std::string> activationsList;
|
||||
if (binder.stringArrayAttr(activationsList, "activations"))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Missing required attribute; activations");
|
||||
|
||||
LstmActivations activations;
|
||||
activations.f = "Sigmoid";
|
||||
activations.g = "Tanh";
|
||||
activations.h = "Tanh";
|
||||
if (activationsList.size() == 3) {
|
||||
activations.f = activationsList[0];
|
||||
activations.g = activationsList[1];
|
||||
activations.h = activationsList[2];
|
||||
} else if (activationsList.size() != 0) {
|
||||
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
|
||||
direction != "forward" && direction != "bidirectional")
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "activations must be empty have 3 elements, but " +
|
||||
binder.op, "Unsupported direction attribute value. "
|
||||
"Only 'forward' / 'bidrectional' are supported but '" +
|
||||
direction + "' is provided.");
|
||||
int64_t num_directions = 1 + (direction == "bidirectional");
|
||||
bool isBidirectional = direction == "bidirectional";
|
||||
// There can be backward activations too
|
||||
// if backward -> look for 6 atcivations (what happens when only three?)
|
||||
|
||||
int64_t num_activations = activationsList.size();
|
||||
if (num_activations != 0 && num_activations != 3 && num_activations != 6) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "activations must either be empty (default), have 3 elements"
|
||||
" (forward) or, have 6 elements (bidirectional), but " +
|
||||
std::to_string(activationsList.size()) +
|
||||
" are provided.");
|
||||
}
|
||||
// TODO : Add checks, defaults and fails for inputs - sequence_lens, P and
|
||||
// attrs- clip, input_forget, layout
|
||||
|
||||
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
|
||||
direction != "forward")
|
||||
Value B;
|
||||
if (binder.tensorOperandAtIndex(B, 3)) {
|
||||
Value none = b.create<ConstantNoneOp>();
|
||||
Value cstHiddenx8 = b.create<ConstantIntOp>(
|
||||
b.getType<IntType>(), b.getI64IntegerAttr(8 * hidden_size));
|
||||
Value cstNumDir = b.create<ConstantIntOp>(
|
||||
b.getType<IntType>(), b.getI64IntegerAttr(num_directions));
|
||||
auto BType = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{num_directions, 8 * hidden_size},
|
||||
cast<ValueTensorType>(W.getType()).getDtype());
|
||||
Value zerosShapeList = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(b.getType<IntType>()),
|
||||
SmallVector<Value>{cstNumDir, cstHiddenx8});
|
||||
B = b.create<AtenZerosOp>(BType, zerosShapeList, none, none, none, none);
|
||||
}
|
||||
|
||||
LstmActivations activations, activationsRev;
|
||||
// Default case (both forward and reverse)
|
||||
activations.f = "Sigmoid";
|
||||
activations.g = "Tanh";
|
||||
activations.h = "Tanh";
|
||||
activationsRev.f = "Sigmoid";
|
||||
activationsRev.g = "Tanh";
|
||||
activationsRev.h = "Tanh";
|
||||
|
||||
// forward only (also to be added for bidirectional case)
|
||||
if (num_activations >= 3) {
|
||||
activations.f = activationsList[0];
|
||||
activations.g = activationsList[1];
|
||||
activations.h = activationsList[2];
|
||||
}
|
||||
|
||||
// bidirectional
|
||||
if (num_activations == 6) {
|
||||
activationsRev.f = activationsList[3];
|
||||
activationsRev.g = activationsList[4];
|
||||
activationsRev.h = activationsList[5];
|
||||
}
|
||||
|
||||
float clip;
|
||||
if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unsupported direction attribute value. "
|
||||
"Only 'forward' is supported but '" +
|
||||
direction + "' is provided.");
|
||||
int64_t num_directions = 1 + (direction == "bidirectional");
|
||||
"clip attribute not supported");
|
||||
|
||||
int64_t input_forget;
|
||||
if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) &&
|
||||
input_forget != 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "only input_forget = 0 supported. Got input_forgt = " +
|
||||
std::to_string(input_forget));
|
||||
|
||||
int64_t layout;
|
||||
if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
|
||||
std::to_string(layout));
|
||||
|
||||
auto XShape = xTy.getSizes();
|
||||
int64_t batch_size = XShape[1];
|
||||
int64_t seq_len, batch_size;
|
||||
if (layout == 0) {
|
||||
seq_len = XShape[0];
|
||||
batch_size = XShape[1];
|
||||
} else {
|
||||
seq_len = XShape[1];
|
||||
batch_size = XShape[0];
|
||||
}
|
||||
|
||||
int64_t input_size = XShape[2];
|
||||
if (num_directions != wTy.getSizes()[0])
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "num_directions (" + std::to_string(num_directions) +
|
||||
") does not match the first dimension of wTy (" +
|
||||
std::to_string(wTy.getSizes()[0]) + ")");
|
||||
if (num_directions != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "num_directions (" + std::to_string(num_directions) +
|
||||
") is not equal to 1");
|
||||
|
||||
if (4 * hidden_size != wTy.getSizes()[1])
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
|
||||
|
@ -746,6 +805,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
Value R_forward = getDirection(b, 0, R);
|
||||
Value B_forward = getDirection(b, 0, B);
|
||||
|
||||
Value W_reverse, R_reverse, B_reverse;
|
||||
if (isBidirectional) {
|
||||
W_reverse = getDirection(b, 1, W);
|
||||
R_reverse = getDirection(b, 1, R);
|
||||
B_reverse = getDirection(b, 1, B);
|
||||
}
|
||||
|
||||
auto hTy = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
||||
xTy.getDtype());
|
||||
|
@ -770,29 +836,44 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
|
||||
Value initial_h;
|
||||
if (binder.tensorOperandAtIndex(initial_h, 5)) {
|
||||
// default created for layout 0
|
||||
initial_h =
|
||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||
} else {
|
||||
if (layout == 1)
|
||||
initial_h = StaticTranspose(b, initial_h, 0, 1);
|
||||
}
|
||||
|
||||
Value initial_c;
|
||||
if (binder.tensorOperandAtIndex(initial_c, 6)) {
|
||||
// default created for layout 0
|
||||
initial_c =
|
||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||
} else {
|
||||
if (layout == 1)
|
||||
initial_c = StaticTranspose(b, initial_c, 0, 1);
|
||||
}
|
||||
|
||||
// convert X from layout 1 to layout 0
|
||||
if (layout == 1)
|
||||
X = StaticTranspose(b, X, 0, 1);
|
||||
|
||||
// X, initial_h, initial_c are now in layout 0
|
||||
|
||||
Value initial_h_forward = getDirection(b, 0, initial_h);
|
||||
Value initial_c_forward = getDirection(b, 0, initial_c);
|
||||
|
||||
if (num_directions != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Unsupported num_directions. Only 1 is supported but " +
|
||||
std::to_string(num_directions) + " is provided.");
|
||||
// TODO: support bidirectional LSTM by doing both directions and replacing
|
||||
// Unsqueeze with Stack
|
||||
Value initial_h_reverse, initial_c_reverse;
|
||||
if (isBidirectional) {
|
||||
initial_h_reverse = getDirection(b, 1, initial_h);
|
||||
initial_c_reverse = getDirection(b, 1, initial_c);
|
||||
}
|
||||
// Everything hereon is for the forward direction, with the direction
|
||||
// dimention squeezed out.
|
||||
|
||||
LstmWeights weights; // weights and biases
|
||||
// Everything hereon is for the forward direction (unless in bidirectional if
|
||||
// block), with the direction dimention squeezed out and all inputs in layout
|
||||
// 0 format
|
||||
|
||||
LstmWeights weights, weightsRev; // weights and biases
|
||||
|
||||
auto intConst = [&](int64_t val) {
|
||||
return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val));
|
||||
|
@ -804,6 +885,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
|
||||
auto biasType = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
|
||||
// forward
|
||||
Value Wb = b.create<AtenSliceTensorOp>(biasType,
|
||||
/*input=*/B_forward,
|
||||
/*dim=*/cstZero,
|
||||
|
@ -816,6 +898,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
/*start=*/recurrentWeightsStartIdx,
|
||||
/*end=*/recurrentWeightsEndIdx,
|
||||
/*step=*/cstOne);
|
||||
Value Wb_reverse, Rb_reverse;
|
||||
if (isBidirectional) {
|
||||
// reverse
|
||||
Wb_reverse = b.create<AtenSliceTensorOp>(biasType,
|
||||
/*input=*/B_reverse,
|
||||
/*dim=*/cstZero,
|
||||
/*start=*/cstZero,
|
||||
/*end=*/inputWeightsEndIdx,
|
||||
/*step=*/cstOne);
|
||||
Rb_reverse = b.create<AtenSliceTensorOp>(biasType,
|
||||
/*input=*/B_reverse,
|
||||
/*dim=*/cstZero,
|
||||
/*start=*/recurrentWeightsStartIdx,
|
||||
/*end=*/recurrentWeightsEndIdx,
|
||||
/*step=*/cstOne);
|
||||
}
|
||||
|
||||
// gate splitting
|
||||
auto gateBiasType = b.getType<ValueTensorType>(
|
||||
|
@ -833,61 +931,164 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
Value forgetGateWeightsEndIdx = intConst(3 * hidden_size);
|
||||
Value cellGateWeightsEndIdx = intConst(4 * hidden_size);
|
||||
|
||||
auto sliceIOFC = [&](std::function<Value(Value, Value)> slicerFunction) {
|
||||
auto sliceIOFC = [&](std::function<Value(Value, Value, Value)> slicerFunction,
|
||||
Value WoB) {
|
||||
// slice into 4 components and return tuple
|
||||
return std::make_tuple(
|
||||
slicerFunction(cstZero, inputGateWeightsEndIdx),
|
||||
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx),
|
||||
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx),
|
||||
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx));
|
||||
slicerFunction(cstZero, inputGateWeightsEndIdx, WoB),
|
||||
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB),
|
||||
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB),
|
||||
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB));
|
||||
};
|
||||
|
||||
auto sliceGateBias = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, Wb, cstZero, startIdx,
|
||||
auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
|
||||
endIdx, cstOne);
|
||||
};
|
||||
std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) =
|
||||
sliceIOFC(sliceGateBias);
|
||||
sliceIOFC(sliceGateBias, Wb);
|
||||
|
||||
auto sliceGateBiasR = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, Rb, cstZero, startIdx,
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f,
|
||||
weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse);
|
||||
|
||||
auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
|
||||
endIdx, cstOne);
|
||||
};
|
||||
std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) =
|
||||
sliceIOFC(sliceGateBiasR);
|
||||
sliceIOFC(sliceGateBiasR, Rb);
|
||||
|
||||
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, W_forward, cstZero,
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f,
|
||||
weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse);
|
||||
|
||||
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, WoB, cstZero,
|
||||
startIdx, endIdx, cstOne);
|
||||
};
|
||||
std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) =
|
||||
sliceIOFC(sliceGateWeightsIH);
|
||||
sliceIOFC(sliceGateWeightsIH, W_forward);
|
||||
|
||||
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, R_forward, cstZero,
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) =
|
||||
sliceIOFC(sliceGateWeightsIH, W_reverse);
|
||||
|
||||
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, WoB, cstZero,
|
||||
startIdx, endIdx, cstOne);
|
||||
};
|
||||
|
||||
std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) =
|
||||
sliceIOFC(sliceGateWeightsHH);
|
||||
sliceIOFC(sliceGateWeightsHH, R_forward);
|
||||
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
|
||||
sliceIOFC(sliceGateWeightsHH, R_reverse);
|
||||
|
||||
LstmLayerOutput lstmLayerOutput = lstm_layer(
|
||||
b, X, initial_h_forward, initial_c_forward, weights, activations);
|
||||
|
||||
auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
|
||||
Value Y_h_result, Y_c_result, Y_result;
|
||||
|
||||
// if forward (unidirectional) unsqueeze and output
|
||||
auto YallDtype =
|
||||
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
|
||||
auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
|
||||
auto Y_uni_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
|
||||
YallDtype);
|
||||
auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
||||
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype());
|
||||
Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(
|
||||
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero);
|
||||
Value Y_c_unsqueezed = b.create<AtenUnsqueezeOp>(
|
||||
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero);
|
||||
YallDtype);
|
||||
auto Y_res_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
|
||||
hidden_size},
|
||||
YallDtype);
|
||||
|
||||
Value Y_h_forward =
|
||||
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
|
||||
|
||||
Value Y_c_forward =
|
||||
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero);
|
||||
|
||||
// unsqueeze num_directions dim1 of Y
|
||||
// to create the onnx.LSTM output shape [seq_length, num_directions,
|
||||
// batch_size, hidden_size]
|
||||
Value Y_unsqueezed =
|
||||
b.create<AtenUnsqueezeOp>(yTy, lstmLayerOutput.Y, cstOne);
|
||||
Value Y_forward =
|
||||
b.create<AtenUnsqueezeOp>(Y_uni_type, lstmLayerOutput.Y, cstOne);
|
||||
|
||||
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed,
|
||||
Y_c_unsqueezed});
|
||||
Y_result = Y_forward;
|
||||
Y_h_result = Y_h_forward;
|
||||
Y_c_result = Y_c_forward;
|
||||
|
||||
// add bidrectional reverse layer
|
||||
// this is just flip X, lstm layer, flip results, stack
|
||||
// flip X
|
||||
Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped,
|
||||
Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list;
|
||||
LstmLayerOutput revLstmLayerOutput;
|
||||
if (isBidirectional) {
|
||||
dim0 = b.create<PrimListConstructOp>(b.getType<ListType>(intType),
|
||||
SmallVector<Value>{cstZero});
|
||||
X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
|
||||
revLstmLayerOutput =
|
||||
lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
|
||||
weightsRev, activationsRev);
|
||||
|
||||
// unsqueeze Y_rev, Y_h_rev, Y_c_rev
|
||||
Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
|
||||
revLstmLayerOutput.Y_h, cstZero);
|
||||
Y_c_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
|
||||
revLstmLayerOutput.Y_c, cstZero);
|
||||
Y_reverse_unflipped =
|
||||
b.create<AtenUnsqueezeOp>(Y_uni_type, revLstmLayerOutput.Y, cstOne);
|
||||
|
||||
// flip Y_rev on dim 0 [seq_len]
|
||||
Y_reverse = b.create<AtenFlipOp>(Y_uni_type, Y_reverse_unflipped, dim0);
|
||||
|
||||
// Concat forward and reverse results on dim 1
|
||||
Y_output_list =
|
||||
b.create<PrimListConstructOp>(b.getType<ListType>(Y_uni_type),
|
||||
SmallVector<Value>{Y_forward, Y_reverse});
|
||||
Y_result = b.create<AtenCatOp>(Y_res_type, Y_output_list, cstOne);
|
||||
|
||||
// Concat forward and reverse results on dim 0
|
||||
Y_h_output_list = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(Y_h_Y_c_uni_type),
|
||||
SmallVector<Value>{Y_h_forward, Y_h_reverse});
|
||||
Y_h_result =
|
||||
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_h_output_list, cstZero);
|
||||
|
||||
Y_c_output_list = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(Y_h_Y_c_uni_type),
|
||||
SmallVector<Value>{Y_c_forward, Y_c_reverse});
|
||||
Y_c_result =
|
||||
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_c_output_list, cstZero);
|
||||
}
|
||||
|
||||
if (layout == 1) {
|
||||
// Update Y, Y_h, Y_c results to layout 1
|
||||
Y_result = StaticTranspose(b, Y_result, 1, 2);
|
||||
Y_result = StaticTranspose(b, Y_result, 0, 1);
|
||||
Y_h_result = StaticTranspose(b, Y_h_result, 0, 1);
|
||||
Y_c_result = StaticTranspose(b, Y_c_result, 0, 1);
|
||||
}
|
||||
|
||||
// Only add outputs specified in onnx output node
|
||||
SmallVector<Value> actualOutputs = {Y_result, Y_h_result, Y_c_result},
|
||||
outputs;
|
||||
ValueTensorType resTy;
|
||||
for (int i = 0; i < binder.getNumResults(); ++i) {
|
||||
if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
|
||||
outputs.push_back(cstNone);
|
||||
} else {
|
||||
outputs.push_back(actualOutputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(binder.op, outputs);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1072,11 +1273,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
|||
Value cstNone = b.create<ConstantNoneOp>();
|
||||
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
|
||||
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
|
||||
Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));
|
||||
|
||||
// Binding arguments
|
||||
ValueTensorType yTy, Y_hType;
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
|
||||
binder.tensorResultTypeAtIndex(Y_hType, 1)) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"At least one output must be present");
|
||||
|
@ -1132,6 +1332,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
|||
// Validations
|
||||
auto XShape = xTy.getSizes();
|
||||
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
|
||||
int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
|
||||
int64_t input_size = XShape[2];
|
||||
|
||||
std::ostringstream oss;
|
||||
|
@ -1173,6 +1374,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
|||
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
|
||||
initial_h =
|
||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||
} else {
|
||||
if (layout == 1) {
|
||||
initial_h = StaticTranspose(b, initial_h, 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (binder.tensorOperandAtIndex(sequence_lens, 4))
|
||||
|
@ -1192,10 +1397,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
|||
// fill in B
|
||||
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
|
||||
if (B == nullptr) {
|
||||
SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
|
||||
SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
|
||||
SmallVector<Value> BShapeListContents = {
|
||||
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
|
||||
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
|
||||
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
|
||||
Value BShapeList = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(intType), BShapeListContents);
|
||||
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
|
||||
|
@ -1256,51 +1461,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
|
|||
B_slices[4], B_slices[5]);
|
||||
|
||||
// Process inputs based on layout
|
||||
Value X_processed, initial_h_processed;
|
||||
ValueTensorType yTy_processed, Y_hType_processed;
|
||||
|
||||
if (layout == 0) {
|
||||
X_processed = X;
|
||||
initial_h_processed = initial_h_forward;
|
||||
yTy_processed = yTy;
|
||||
Y_hType_processed = Y_hType;
|
||||
} else {
|
||||
X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
|
||||
initial_h_processed = b.create<AtenTransposeIntOp>(
|
||||
initial_h.getType(), initial_h_forward, cstZero, cstOne);
|
||||
|
||||
auto yTySizes = yTy.getSizes();
|
||||
auto Y_hTypeSizes = Y_hType.getSizes();
|
||||
|
||||
yTy_processed = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
|
||||
yTySizes[3]},
|
||||
yTy.getDtype());
|
||||
|
||||
Y_hType_processed = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
|
||||
Y_hTypeSizes[2]},
|
||||
Y_hType.getDtype());
|
||||
if (layout == 1) {
|
||||
X = StaticTranspose(b, X, 0, 1);
|
||||
}
|
||||
|
||||
// Weights and biases ready. Calling GRU layer to insert the actual ops.
|
||||
GruLayerOutput gruLayerOutput =
|
||||
gru_layer(b, X_processed, initial_h_processed, weights, activations,
|
||||
linear_before_reset);
|
||||
GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
|
||||
activations, linear_before_reset);
|
||||
|
||||
// Process outputs based on layout
|
||||
Value Y_final, Y_h_final;
|
||||
Value Y_final;
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0)) {
|
||||
Y_final = cstNone;
|
||||
} else {
|
||||
if (layout == 0) {
|
||||
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
|
||||
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
|
||||
} else {
|
||||
auto Y_transposed = b.create<AtenTransposeIntOp>(
|
||||
gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
|
||||
Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
|
||||
Type yTy_original = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
|
||||
yTy.getDtype());
|
||||
Y_final =
|
||||
b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
|
||||
Y_final = StaticTranspose(b, Y_final, 1, 2);
|
||||
Y_final = StaticTranspose(b, Y_final, 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
auto Y_h_transposed = b.create<AtenTransposeIntOp>(
|
||||
gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
|
||||
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
|
||||
Value Y_h_final;
|
||||
if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
|
||||
Y_h_final = cstNone;
|
||||
} else {
|
||||
if (layout == 0) {
|
||||
Y_h_final =
|
||||
b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
|
||||
} else {
|
||||
Type y_hTy_original = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
|
||||
Y_hType.getDtype());
|
||||
Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
|
||||
cstZero);
|
||||
Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
|
||||
|
|
|
@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
|
|||
template <typename OpTy, typename OpAdaptor>
|
||||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
int64_t &dim,
|
||||
SmallVector<Value> &resultShape,
|
||||
SmallVector<Value> &offsets,
|
||||
SmallVector<Value> &strides) {
|
||||
|
@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("unimplemented: dim is not constant");
|
||||
|
||||
|
@ -1658,10 +1658,17 @@ public:
|
|||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
// TODO: Handle the case where the dim(th) dimension is dynamic.
|
||||
// assert dynamic squeeze dim size == 1
|
||||
if (inputType.isDynamicDim(dim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
|
||||
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
|
||||
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
|
||||
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
|
||||
Value cmp = rewriter.create<arith::CmpIOp>(
|
||||
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
op.getLoc(), cmp,
|
||||
rewriter.getStringAttr(
|
||||
"Expected dynamic squeeze dim size to be statically 1"));
|
||||
}
|
||||
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
|
@ -1671,7 +1678,7 @@ public:
|
|||
|
||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||
// `aten.squeeze` will behave as an identity operation.
|
||||
if (inputType.getDimSize(dim) != 1) {
|
||||
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
||||
return success();
|
||||
}
|
||||
|
@ -1857,14 +1864,46 @@ public:
|
|||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
SmallVector<Value> resultShape, offsets, strides;
|
||||
int64_t dim;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
|
||||
AtenSliceTensorOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If stride is negative, then flip the input tensor corresponding to that
|
||||
// dim, update the stride for flipped tensor by multiplying it by -1, and
|
||||
// update the offset as follows:
|
||||
// flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
|
||||
//
|
||||
// For example:
|
||||
// Input = [0, 1, 2, 3, 4, 5]
|
||||
// stride = [-2], result_shape = [2], offset = [3]
|
||||
// Result = [3, 1]
|
||||
// After flipping:
|
||||
// Input = [5, 4, 3, 2, 1, 0]
|
||||
// stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
|
||||
// Result = [3, 1]
|
||||
|
||||
Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
|
||||
SmallVector<int64_t>{dim});
|
||||
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value isNegativeStride = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, strides[dim], zero);
|
||||
strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
|
||||
Value resShapeMulStride =
|
||||
rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
|
||||
Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
|
||||
Value flippedOffset =
|
||||
rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
|
||||
offsets[dim] = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegativeStride, flippedOffset, offsets[dim]);
|
||||
|
||||
input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
|
||||
flippedInput, input);
|
||||
|
||||
SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
|
||||
auto sliceType = RankedTensorType::get(
|
||||
dynShape, resultType.getElementType(), resultType.getEncoding());
|
||||
|
@ -2095,12 +2134,11 @@ public:
|
|||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
SmallVector<Value> resultShape, offsets, strides;
|
||||
int64_t dim;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||
AtenSliceScatterOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -2573,6 +2611,167 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenUnfoldOp : public OpConversionPattern<AtenUnfoldOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto self = adaptor.getSelf();
|
||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||
|
||||
int64_t dimension;
|
||||
if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dimension");
|
||||
}
|
||||
int64_t size;
|
||||
if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) {
|
||||
return rewriter.notifyMatchFailure(op, "only support constant int size");
|
||||
}
|
||||
int64_t step;
|
||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||
return rewriter.notifyMatchFailure(op, "only support constant int step");
|
||||
}
|
||||
|
||||
if (step <= 0) {
|
||||
return rewriter.notifyMatchFailure(op, "step must be greater than zero.");
|
||||
}
|
||||
|
||||
int64_t selfRank = selfType.getRank();
|
||||
|
||||
// Zero-Rank case
|
||||
if (selfRank == 0) {
|
||||
// Empty tensor
|
||||
if (size == 0) {
|
||||
RankedTensorType resultType =
|
||||
RankedTensorType::get({0}, selfType.getElementType());
|
||||
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, resultType.getShape(), resultType.getElementType());
|
||||
|
||||
rewriter.replaceOp(op, emptyTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value unsqueezedSelf = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, RankedTensorType::get({1}, selfType.getElementType()), self,
|
||||
ArrayRef<ReassociationIndices>{});
|
||||
rewriter.replaceOp(op, unsqueezedSelf);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto shape = selfType.getShape();
|
||||
|
||||
if (dimension < 0) {
|
||||
dimension = toPositiveDim(dimension, selfRank);
|
||||
}
|
||||
if (!isValidDim(dimension, selfRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "dimension out of range");
|
||||
}
|
||||
|
||||
Value dimSize = rewriter.create<tensor::DimOp>(loc, self, dimension);
|
||||
|
||||
Value sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
|
||||
Value sizeCheck = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ule, sizeValue, dimSize);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, sizeCheck,
|
||||
rewriter.getStringAttr("size must be <= target dimension"));
|
||||
|
||||
/* Calculate output shape of unfold op:
|
||||
* https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
|
||||
* outputShape[dimension] is set to numBlocks, with size appended as an
|
||||
* additional dimension
|
||||
*/
|
||||
SmallVector<OpFoldResult> outputShape;
|
||||
for (int64_t i = 0; i < selfRank; i++) {
|
||||
if (i == dimension) {
|
||||
outputShape.push_back(getDynamicOrStaticNumBlocks(
|
||||
rewriter, loc, shape[dimension], dimSize, size, step));
|
||||
} else if (shape[i] == ShapedType::kDynamic) {
|
||||
outputShape.push_back(
|
||||
OpFoldResult(rewriter.create<tensor::DimOp>(loc, self, i)));
|
||||
} else {
|
||||
outputShape.push_back(rewriter.getIndexAttr(shape[i]));
|
||||
}
|
||||
}
|
||||
outputShape.push_back(rewriter.getIndexAttr(size));
|
||||
|
||||
// Empty tensor to insert values into
|
||||
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, outputShape, selfType.getElementType());
|
||||
|
||||
/**
|
||||
* Use reindexing to map output indices to input indices
|
||||
* i.e. In output of rank 3 case:
|
||||
* (i, j, k) => (i', j') where i' = i * step + k and j' = j
|
||||
* if dimension == 0
|
||||
* (i, j, k) => (i', j') where i' = i and j' = j * step + k
|
||||
* if dimension == 1
|
||||
*/
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
SmallVector<AffineExpr> outputExprs;
|
||||
for (int dim = 0; dim < selfRank; ++dim) {
|
||||
if (dim == dimension) {
|
||||
auto idxLast = getAffineDimExpr(selfRank, context);
|
||||
auto idxDimension = getAffineDimExpr(dimension, context);
|
||||
|
||||
AffineExpr dimIdx =
|
||||
idxLast + idxDimension * rewriter.getAffineConstantExpr(step);
|
||||
outputExprs.push_back(dimIdx);
|
||||
} else {
|
||||
outputExprs.push_back(getAffineDimExpr(dim, context));
|
||||
}
|
||||
}
|
||||
|
||||
int64_t outputRank = selfRank + 1;
|
||||
auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context);
|
||||
auto outputAffineMap =
|
||||
AffineMap::getMultiDimIdentityMap(outputRank, context);
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
outputRank, utils::IteratorType::parallel);
|
||||
|
||||
Value result =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outputTensor.getType(), self, outputTensor,
|
||||
ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes,
|
||||
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(nestedLoc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc,
|
||||
int64_t shapeDim, Value dimSize,
|
||||
int64_t size, int64_t step) const {
|
||||
/**
|
||||
* numBlocks = (shape[dimension] - size) // step + 1
|
||||
*/
|
||||
if (shapeDim == ShapedType::kDynamic) {
|
||||
Value numBlocksSubOp = rewriter.create<arith::SubIOp>(
|
||||
loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, size));
|
||||
Value numBlocksDivOp = rewriter.create<arith::DivUIOp>(
|
||||
loc, numBlocksSubOp,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, step));
|
||||
Value numBlocks = rewriter.create<arith::AddIOp>(
|
||||
loc, rewriter.create<arith::ConstantIndexOp>(loc, 1), numBlocksDivOp);
|
||||
return OpFoldResult(numBlocks);
|
||||
}
|
||||
|
||||
int64_t staticNumBlocks = (shapeDim - size) / step + 1;
|
||||
return rewriter.getIndexAttr(staticNumBlocks); // Use static value
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
|
||||
public:
|
||||
|
@ -2641,7 +2840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
/*benefit=*/200);
|
||||
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
|
||||
/*benefit=*/100);
|
||||
|
||||
target.addIllegalOp<AtenUnfoldOp>();
|
||||
patterns.add<ConvertAtenUnfoldOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeDimOp>();
|
||||
|
|
|
@ -301,14 +301,9 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfRank =
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
Type elementType =
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
|
||||
Value c1 =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
|
||||
SmallVector<int64_t> axis;
|
||||
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
||||
|
@ -321,40 +316,8 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
||||
// dims won't be used.
|
||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
|
||||
for (auto flipDim : axis)
|
||||
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
|
||||
|
||||
Value initTensor = createZeroInitTensor(
|
||||
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
selfRank, utils::IteratorType::parallel);
|
||||
SmallVector<AffineMap> indexingMaps(
|
||||
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
|
||||
Value flipped =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, self.getType(), self, initTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
SmallVector<Value> indices;
|
||||
for (auto i = 0; i < selfRank; i++)
|
||||
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||
for (auto flipDim : axis) {
|
||||
indices[flipDim] = b.create<arith::SubIOp>(
|
||||
loc, dims[flipDim], indices[flipDim]);
|
||||
}
|
||||
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
|
||||
.getResult();
|
||||
b.create<linalg::YieldOp>(loc, res);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1300,10 +1263,6 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D grouped convolution supported");
|
||||
|
||||
// Special depthwise case: Cin = Cout = groups.
|
||||
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
|
||||
// of groups) to be depthwise in their documentation, but the linalg ops
|
||||
|
@ -1315,21 +1274,45 @@ public:
|
|||
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
|
||||
weightShape[1] == 1) {
|
||||
// Collapse weight shape (C/G == 1)
|
||||
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
||||
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
|
||||
weightShape[2], weightShape[3]};
|
||||
SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
|
||||
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
|
||||
for (unsigned i = 0; i < numSpatialDims; i++) {
|
||||
collapsedDims.push_back({i + 2});
|
||||
collapsedShape.push_back(weightShape[i + 2]);
|
||||
}
|
||||
Type collapsedType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(collapsedShape), weightDTy);
|
||||
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, collapsedType, weight, collapsedDims);
|
||||
if (!inputZp) {
|
||||
switch (numSpatialDims) {
|
||||
case 1:
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv1DNcwCwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
break;
|
||||
case 2:
|
||||
conv = rewriter
|
||||
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
||||
loc, outputTensor.getType(),
|
||||
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
||||
stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
break;
|
||||
default:
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 1D and 2D depthwise convolution "
|
||||
"supported for special case of group convolution");
|
||||
};
|
||||
} else {
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D depthwise quantized convolution "
|
||||
"supported for special case of group convolution");
|
||||
|
||||
// currently, the only named depthwise qconv op is nhwc_hwc
|
||||
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
|
||||
// linalg conv result nhwc -> nchw
|
||||
|
@ -1376,6 +1359,10 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
if (numSpatialDims != 2)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D grouped convolution supported");
|
||||
|
||||
// Grouped case, use the grouped conv linalg op
|
||||
auto expandGroups = [&](Value tensor, size_t dim) {
|
||||
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||
|
|
|
@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
|
||||
return createEqual(b, loc, floatDtype, self, zero);
|
||||
}
|
||||
if (auto complex = dyn_cast<AtenComplexOp>(op)) {
|
||||
auto ctype = cast<ComplexType>(
|
||||
cast<RankedTensorType>(converter->convertType(complex.getType()))
|
||||
.getElementType());
|
||||
Type stype = ctype.getElementType();
|
||||
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype);
|
||||
return b.create<complex::CreateOp>(loc, ctype, lhs, rhs);
|
||||
}
|
||||
if (isa<AtenAbsOp>(op)) {
|
||||
if (isa<IntegerType>(payloadArgs[0].getType()))
|
||||
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
|
||||
|
@ -1590,22 +1600,22 @@ public:
|
|||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
||||
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
|
||||
Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
|
||||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
||||
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
|
||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
|
||||
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
|
||||
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
|
||||
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
|
||||
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
|
||||
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
|
||||
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
|
||||
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
|
||||
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
|
@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
|
||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||
|
|
|
@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
|
|||
.getResult(0);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Flips an input tensor based on the values of axis list.
|
||||
Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
|
||||
Value input, SmallVector<int64_t> axis) {
|
||||
Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
|
||||
auto selfRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
|
||||
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
||||
// dims won't be used.
|
||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
|
||||
for (auto flipDim : axis)
|
||||
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
|
||||
|
||||
Value initTensor = createZeroInitTensor(
|
||||
rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes(selfRank,
|
||||
utils::IteratorType::parallel);
|
||||
SmallVector<AffineMap> indexingMaps(
|
||||
2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
|
||||
Value flipped =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, input.getType(), input, initTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
SmallVector<Value> indices;
|
||||
for (auto i = 0; i < selfRank; i++)
|
||||
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||
for (auto flipDim : axis) {
|
||||
indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
|
||||
indices[flipDim]);
|
||||
}
|
||||
Value res = b.create<tensor::ExtractOp>(loc, input, indices)
|
||||
.getResult();
|
||||
b.create<linalg::YieldOp>(loc, res);
|
||||
})
|
||||
.getResult(0);
|
||||
return flipped;
|
||||
}
|
||||
|
|
|
@ -325,7 +325,8 @@ public:
|
|||
lhsContractingDim, rhsContractingDim);
|
||||
output = rewriter
|
||||
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||
dotDimensionNumbers, nullptr)
|
||||
dotDimensionNumbers, nullptr,
|
||||
nullptr)
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
|
@ -494,7 +495,7 @@ public:
|
|||
/*lhsContractingDimensions=*/{lhsContractingDim},
|
||||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
|
||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr);
|
||||
|
||||
Value matmulPlusBias = matmulOutput;
|
||||
if (!isa<Torch::NoneType>(biasTy)) {
|
||||
|
|
|
@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
|
||||
// Max pooling
|
||||
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
|
||||
AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
if (isa<mlir::FloatType>(elementTy)) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType,
|
||||
|
@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// AtenMaxPool1dWithIndicesOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
|
||||
AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputShape = inputTy.getShape();
|
||||
auto inputRank = inputTy.getRank();
|
||||
|
||||
auto outValTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
||||
auto outIdxTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||
|
||||
if (inputRank <= 1) {
|
||||
return op.emitError(
|
||||
"max_pooling1d only supports inputs with rank higher than 1");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
|
||||
bool ceilMode = false;
|
||||
|
||||
if (!(matchPattern(op.getKernelSize(),
|
||||
m_TorchListOfConstantInts(kernelSize)))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const int kernel size unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
||||
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int padding unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int dilation unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const bool ceil_mode unsupported!");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 1);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
stablehloDilation.begin() + inputRank - 1);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
stablehloKernelSize.begin() + inputRank - 1);
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
stablehloPadding);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
|
||||
auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
|
||||
// no need to reshape here for max_pool_1d. Need to make sure the iota
|
||||
// dimension. dim=inputRank-2 or dim=inputRank-1?
|
||||
auto indexTensor =
|
||||
rewriter
|
||||
.create<stablehlo::DynamicIotaOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputShape, rewriter.getI64Type()),
|
||||
inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
|
||||
.getResult();
|
||||
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
|
||||
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
|
||||
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
|
||||
|
||||
// add block.
|
||||
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
||||
auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
|
||||
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
|
||||
block.addArgument(blockValArgumentType, op->getLoc());
|
||||
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||
block.addArgument(blockValArgumentType, op->getLoc());
|
||||
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||
auto *firstValArg = block.args_begin();
|
||||
auto *firstIdxArg = std::next(firstValArg);
|
||||
auto *secondValArg = std::next(firstIdxArg);
|
||||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
|
||||
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
||||
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareGeDirectionAttr, compareTypeAttr);
|
||||
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||
|
||||
// Get smaller index if compared values are equal.
|
||||
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareEqDirectionAttr, compareTypeAttr);
|
||||
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
||||
*secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||
|
||||
rewriter.create<stablehlo::ReturnOp>(
|
||||
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenMaxPool2dWithIndicesOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||
|
@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
|||
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
|
||||
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
|
||||
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
|
||||
#undef INSERT_ATEN_POOLING_PATTERN
|
||||
|
|
|
@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
}
|
||||
}
|
||||
|
||||
if (isa<AtenAllOp>(op)) {
|
||||
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
||||
auto constAttr =
|
||||
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
|
@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
|
|||
AtenLinalgVectorNormOp>(op)) {
|
||||
result = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
} else if (isa<AtenAllOp>(op)) {
|
||||
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
|
||||
result = rewriter.create<stablehlo::AndOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
|
||||
|
@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
|||
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
|
||||
options)
|
||||
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
|
||||
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
|
||||
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
|
||||
|
||||
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -161,12 +161,70 @@ public:
|
|||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
||||
unsigned getBitWidth(Type type) const {
|
||||
if (auto complexTy = dyn_cast<ComplexType>(type))
|
||||
return 2 * getBitWidth(complexTy.getElementType());
|
||||
return type.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!rankType)
|
||||
return op.emitError("Only ranked tensor types are currently supported");
|
||||
return op.emitError("Only ranked tensor types are currently supported.");
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// support AtenViewDtypeOp
|
||||
if (isa<AtenViewDtypeOp>(op)) {
|
||||
auto self = adaptor.getSelf();
|
||||
auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
|
||||
|
||||
// infer the result shape
|
||||
auto operandElt = rankType.getElementType();
|
||||
auto targetElt = baseResultTy.getDtype();
|
||||
auto operandEltBitWidth = getBitWidth(operandElt);
|
||||
auto targetEltBitWidth = getBitWidth(targetElt);
|
||||
auto operandSizes = rankType.getShape();
|
||||
SmallVector<int64_t> castShape(operandSizes);
|
||||
if (operandEltBitWidth > targetEltBitWidth) {
|
||||
int64_t last_size = operandEltBitWidth / targetEltBitWidth;
|
||||
castShape.push_back(last_size);
|
||||
} else if (operandEltBitWidth < targetEltBitWidth) {
|
||||
int64_t last_size = targetEltBitWidth / operandEltBitWidth;
|
||||
if (!ShapedType::isDynamic(castShape.back()) and
|
||||
last_size != castShape.back()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The last dim size is not equal to targetEltBitWidth / "
|
||||
"operandEltBitWidth.");
|
||||
} else {
|
||||
castShape.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
auto resultType =
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
baseResultTy);
|
||||
if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only support static output shape.");
|
||||
}
|
||||
|
||||
auto castType =
|
||||
baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
|
||||
auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
|
||||
loc,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
castType),
|
||||
self);
|
||||
|
||||
auto reshape =
|
||||
rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
|
||||
|
||||
rewriter.replaceOp(op, reshape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// collect Value of dims
|
||||
SmallVector<Value, 4> dimSizes;
|
||||
|
@ -174,7 +232,6 @@ public:
|
|||
return op.emitError("Dims size must be a list of Scalar");
|
||||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op,
|
||||
|
@ -236,6 +293,13 @@ public:
|
|||
SmallVector<Value, 4> &dimSizes) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
|
||||
AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 4> &dimSizes) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
|
||||
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||
|
@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
|||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
||||
#undef INSERT_VIEW_OP_PATTERN
|
||||
|
|
|
@ -1497,6 +1497,79 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
Type elementType = resultType.getElementType();
|
||||
Type inputElementType =
|
||||
cast<RankedTensorType>(input.getType()).getElementType();
|
||||
|
||||
// Converting the input element type to the result's element type.
|
||||
// The only possible mismatch would be when the input element type is an
|
||||
// integer but not `si64`. Therefore, we directly convert the input to
|
||||
// `si64`. Rest all cases are handled in the dtype definition for this op.
|
||||
if (elementType != inputElementType) {
|
||||
Value torchInput = convertTensorToDtype(
|
||||
rewriter, loc, op.getSelf(),
|
||||
rewriter.getIntegerType(64, IntegerType::Signed));
|
||||
input = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(torchInput.getType()),
|
||||
torchInput);
|
||||
}
|
||||
|
||||
int64_t inputRank = resultType.getRank();
|
||||
Value dtype = op.getDtype();
|
||||
if (!isa<Torch::NoneType>(dtype.getType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported: dtype argument not supported");
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only constant dim value is supported");
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "invalid dim");
|
||||
|
||||
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
|
||||
Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
|
||||
output = rewriter.create<tensor::CastOp>(loc, resultType, output);
|
||||
|
||||
SmallVector<Value> accSizes(sizes);
|
||||
accSizes.erase(accSizes.begin() + dim);
|
||||
SmallVector<int64_t> accStatic(
|
||||
makeShapeTorchCompatible(resultType.getShape()));
|
||||
accStatic.erase(accStatic.begin() + dim);
|
||||
Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
|
||||
Type accType =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
|
||||
acc = rewriter.create<tensor::CastOp>(loc, accType, acc);
|
||||
|
||||
Value result = createTMTensorScanOp(
|
||||
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
|
||||
[](OpBuilder &b, Location loc, Value input, Value acc) {
|
||||
Value prod =
|
||||
(isa<mlir::FloatType>(input.getType())
|
||||
? b.create<arith::MulFOp>(loc, input, acc)->getResult(0)
|
||||
: b.create<arith::MulIOp>(loc, input, acc)->getResult(0));
|
||||
b.create<TMTensor::YieldOp>(loc, prod);
|
||||
});
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
|
||||
public:
|
||||
|
@ -2240,6 +2313,8 @@ public:
|
|||
patterns.add<ConvertAtenSortOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCumsumOp>();
|
||||
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCumprodOp>();
|
||||
patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenScaledDotProductAttentionOp>();
|
||||
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
||||
context);
|
||||
|
|
|
@ -153,9 +153,15 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Unable to extract the scalar constant");
|
||||
|
||||
int64_t numElem = 1;
|
||||
for (int64_t dim : dshape)
|
||||
numElem *= dim;
|
||||
|
||||
if (isa<mlir::FloatType>(dtype)) {
|
||||
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
|
||||
(isFloat ? doubleValue : intValue),
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<float>(
|
||||
rewriter, op,
|
||||
SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
|
||||
dshape, dtype)
|
||||
.value();
|
||||
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
|
@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
bool d = isFloat ? static_cast<bool>(doubleValue)
|
||||
: static_cast<bool>(intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
|
||||
tosaTensor = tosa::getConstTensor<bool>(
|
||||
rewriter, op, SmallVector<bool>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (w == 32) {
|
||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||
: static_cast<int32_t>(intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
|
||||
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (w == 64) {
|
||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
"of destination type");
|
||||
}
|
||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
|
||||
tosaTensor = tosa::getConstTensor<int64_t>(
|
||||
rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
|
||||
.value();
|
||||
}
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
||||
|
@ -891,8 +900,6 @@ public:
|
|||
if (!result)
|
||||
return failure();
|
||||
|
||||
// TBD - support dtype casting.
|
||||
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
|
@ -2842,8 +2849,12 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||
}
|
||||
|
||||
auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
|
||||
rewriter, op.getOperation(), dimListInt, {selfRank});
|
||||
SmallVector<int32_t> dimListInt32;
|
||||
for (auto v : dimListInt)
|
||||
dimListInt32.push_back(v);
|
||||
|
||||
auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
|
||||
rewriter, op.getOperation(), dimListInt32, {selfRank});
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
|
||||
|
@ -3819,6 +3830,124 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
||||
AtenIndexSelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Not a tensor type.
|
||||
auto input = adaptor.getSelf();
|
||||
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only RankedTensorType inputs are currently supported");
|
||||
|
||||
auto index = adaptor.getIndex();
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
|
||||
if (!indexType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only RankedTensorType indices are currently supported");
|
||||
|
||||
auto inputShape = inputType.getShape();
|
||||
int inputRank = inputType.getRank();
|
||||
|
||||
if (indexType.getRank() == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Rank 0 index tensor is currently not supported");
|
||||
|
||||
// Dynamic shape check
|
||||
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "AtenIndexSelectOp: support for dynamic input "
|
||||
"shape not implemented");
|
||||
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexType.getShape(),
|
||||
rewriter.getIntegerType(32)),
|
||||
index);
|
||||
}
|
||||
|
||||
// Get positive dim
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Value `dim` should be a torch constant int");
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "Value `dim` is invalid");
|
||||
|
||||
// Get the output type
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
// Reshape and expand the index tensor to have same rank and same dimensions
|
||||
// (except for the targeted dim) as the input
|
||||
//
|
||||
// For example:
|
||||
// Input shape = (4, 5, 6)
|
||||
// Index vector shape = (2)
|
||||
// Targeted dim = 1
|
||||
// Reshaped and expanded index vector shape = (4, 2, 6)
|
||||
//
|
||||
// By reshaping and expanding the index vector, we can supply it into the
|
||||
// gather op to mimic the functionality of aten.index_select
|
||||
SmallVector<int64_t> indicesInputRankShape;
|
||||
for (int64_t i = 0; i < inputRank; i++) {
|
||||
if (i == dim) {
|
||||
indicesInputRankShape.push_back(indexType.getShape()[0]);
|
||||
} else {
|
||||
indicesInputRankShape.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
auto indicesInputRankType =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape),
|
||||
rewriter.getIntegerType(32));
|
||||
|
||||
auto reshapedIndices = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), indicesInputRankType, index,
|
||||
rewriter.getDenseI64ArrayAttr(indicesInputRankShape));
|
||||
|
||||
SmallVector<int64_t> tileShape(indicesInputRankShape);
|
||||
SmallVector<int64_t> expandedIndicesShape(indicesInputRankShape);
|
||||
for (int64_t i = 0; i < inputRank; i++) {
|
||||
if (tileShape[i] == 1 && i != dim) {
|
||||
tileShape[i] = inputShape[i];
|
||||
expandedIndicesShape[i] = inputShape[i];
|
||||
} else {
|
||||
tileShape[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto tileType =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
|
||||
rewriter.getIntegerType(32));
|
||||
|
||||
auto expandedIndices = rewriter.create<tosa::TileOp>(
|
||||
op->getLoc(), tileType, reshapedIndices.getResult(),
|
||||
rewriter.getDenseI64ArrayAttr(tileShape));
|
||||
|
||||
// convert torch style index and dim into tf style indices
|
||||
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
|
||||
auto indicesTf = tosa::convertTorchIndexToTfIndices(
|
||||
rewriter, op, input, expandedIndices.getResult(), dim);
|
||||
if (!indicesTf)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert TorchIndex To TfIndices failed");
|
||||
|
||||
// do the tf gathernd algorithm with tf style indices as input.
|
||||
auto result =
|
||||
tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value());
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
|
||||
}
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
||||
|
@ -5200,7 +5329,7 @@ public:
|
|||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
|
||||
class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
@ -5216,18 +5345,48 @@ public:
|
|||
op, "Only Tensor types with static shapes are currently supported");
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
Value constOp;
|
||||
|
||||
Value fillValueTargetTensor;
|
||||
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
|
||||
// Reshape value tensor to have same rank and shape as input
|
||||
auto inputRank =
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
|
||||
auto fillValue = adaptor.getValue();
|
||||
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
|
||||
if (!fillValueType)
|
||||
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
|
||||
auto fillValueElemTy = fillValueType.getElementType();
|
||||
|
||||
SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
|
||||
|
||||
auto fillValueMatchedInputRankType = RankedTensorType::get(
|
||||
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
|
||||
fillValueElemTy);
|
||||
|
||||
auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), fillValueMatchedInputRankType, fillValue,
|
||||
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
|
||||
|
||||
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
|
||||
fillValueElemTy),
|
||||
fillValueMatchedInputRankTensor.getResult(),
|
||||
makeShapeTorchCompatible(outType.getShape()));
|
||||
} else {
|
||||
if (failed(torchScalarToTosaTensor(
|
||||
rewriter, op, op.getValue(), constOp, outElemTy,
|
||||
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
|
||||
makeShapeTorchCompatible(outType.getShape()))))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Supplied value must be a Scalar constant");
|
||||
op, "Fill value must be a scalar constant");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
|
||||
fillValueTargetTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -5647,8 +5806,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Template to create support tril mask tensor for aten.tril
|
||||
// legalization
|
||||
// Template to create supporting tril mask tensor for aten.tril
|
||||
template <typename T>
|
||||
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||
|
@ -5671,28 +5829,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
|
|||
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
||||
}
|
||||
|
||||
// Function to get tril mask tensor based on input type
|
||||
// for aten.tril legalization
|
||||
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||
int64_t diagonal, Type type) {
|
||||
return TypeSwitch<Type, Value>(type)
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 1:
|
||||
return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
|
||||
case 32:
|
||||
return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
|
||||
case 64:
|
||||
return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
}
|
||||
|
||||
// Legalization for aten.tril
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||
|
@ -5740,14 +5876,31 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
|
||||
|
||||
// Define shape for mask tensor based on rank
|
||||
SmallVector<int64_t> constShape;
|
||||
SmallVector<int64_t> maskShape;
|
||||
for (auto i = 0; i < selfRank - 2; i++)
|
||||
constShape.push_back(1);
|
||||
constShape.push_back(h);
|
||||
constShape.push_back(w);
|
||||
maskShape.push_back(1);
|
||||
maskShape.push_back(h);
|
||||
maskShape.push_back(w);
|
||||
|
||||
Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
|
||||
resultType.getElementType());
|
||||
Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return createTrilMask<float>(rewriter, op, maskShape,
|
||||
h, w, diagonal);
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 1:
|
||||
return createTrilMask<bool>(rewriter, op, maskShape,
|
||||
h, w, diagonal);
|
||||
case 32:
|
||||
return createTrilMask<int32_t>(
|
||||
rewriter, op, maskShape, h, w, diagonal);
|
||||
case 64:
|
||||
return createTrilMask<int64_t>(
|
||||
rewriter, op, maskShape, h, w, diagonal);
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
|
||||
/*shift=*/0);
|
||||
|
@ -5755,6 +5908,311 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.flip
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||
AtenFlipOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are currently supported");
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only constant dims are currently supported");
|
||||
|
||||
auto selfRank = selfTy.getRank();
|
||||
|
||||
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||
Value result = self;
|
||||
|
||||
for (auto &dim : dims) {
|
||||
dim = toPositiveDim(dim, selfRank);
|
||||
if (!isValidDim(dim, selfRank))
|
||||
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||
|
||||
result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
|
||||
static_cast<int32_t>(dim));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.round:
|
||||
// Rounds elements of input to the nearest integer.
|
||||
// Implements "round half to even" to break ties when a number is equidistant
|
||||
// from two integers.
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
|
||||
AtenRoundOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// To round to the nearest integer, we will consider the fractional part of
|
||||
// the input element (= input element - integer part of element). If the
|
||||
// fractional part is smaller than 0.5, round the number down. If the
|
||||
// fractional part is 0.5, apply "round half to even" rule. If the fractional
|
||||
// part is greater than 0.5, round up.
|
||||
//
|
||||
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
|
||||
// res = floor(input)
|
||||
// else:
|
||||
// res = ceil(input)
|
||||
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types supported");
|
||||
|
||||
auto resultTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
auto boolTy =
|
||||
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
|
||||
|
||||
auto resultElemTy = resultTy.getElementType();
|
||||
|
||||
auto oneHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
|
||||
|
||||
auto two =
|
||||
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
|
||||
|
||||
auto floorInput =
|
||||
rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
|
||||
|
||||
// input - floor(input)
|
||||
auto fractionalPart = rewriter.create<tosa::SubOp>(
|
||||
op->getLoc(), resultTy, self, floorInput.getResult());
|
||||
|
||||
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
|
||||
|
||||
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
|
||||
|
||||
auto floorDivResult = rewriter.create<tosa::FloorOp>(
|
||||
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
|
||||
|
||||
// (floor(input) // 2) * 2
|
||||
auto evenComparison = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
|
||||
|
||||
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
|
||||
auto floorInputEven = rewriter.create<tosa::EqualOp>(
|
||||
op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
|
||||
|
||||
auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
|
||||
op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
|
||||
|
||||
auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
|
||||
op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
|
||||
|
||||
// (frac == 0.5) && (floor(input) % 2 == 0)
|
||||
auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
|
||||
op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
|
||||
floorInputEven.getResult());
|
||||
|
||||
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
|
||||
auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
|
||||
op->getLoc(), boolTy, fracLtOneHalf.getResult(),
|
||||
fracEqualOneHalfCond.getResult());
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
|
||||
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
|
||||
ceilInput.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Template to create supporting diagonal mask tensor for aten.diagonal
|
||||
template <typename T>
|
||||
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<int64_t> shape, int64_t h, int64_t w,
|
||||
int64_t offset) {
|
||||
SmallVector<T> vec;
|
||||
|
||||
for (int64_t i = 0; i < h; i++) {
|
||||
for (int64_t j = 0; j < w; j++) {
|
||||
// Positive offset value moves above the main diagonal, while negative
|
||||
// diagonal value moves below the main diagonal.
|
||||
if (i + offset == j) {
|
||||
vec.push_back(static_cast<T>(1));
|
||||
} else {
|
||||
vec.push_back(static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
|
||||
}
|
||||
|
||||
// Legalization for aten.diagonal
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
||||
AtenDiagonalOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a ranked tensor type
|
||||
auto selfType = dyn_cast<RankedTensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are supported");
|
||||
|
||||
// Rank below 2 not accepted
|
||||
auto selfRank = selfType.getRank();
|
||||
if (selfRank <= 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Rank 0 and 1 are not accepted as they cause underflow");
|
||||
|
||||
if (!selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only static shapes are supported");
|
||||
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
RankedTensorType resultType = cast<RankedTensorType>(
|
||||
typeConverter->convertType(op->getResult(0).getType()));
|
||||
if (!resultType)
|
||||
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
|
||||
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
auto resultElemTy = resultType.getElementType();
|
||||
|
||||
int64_t offset, dim1, dim2;
|
||||
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
|
||||
offset = 0;
|
||||
|
||||
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
|
||||
dim1 = 0;
|
||||
} else {
|
||||
dim1 = toPositiveDim(dim1, selfRank);
|
||||
}
|
||||
|
||||
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
|
||||
dim2 = 1;
|
||||
} else {
|
||||
dim2 = toPositiveDim(dim2, selfRank);
|
||||
}
|
||||
|
||||
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
|
||||
int64_t h = selfShape[dim1];
|
||||
int64_t w = selfShape[dim2];
|
||||
|
||||
// Overflowing offset not supported
|
||||
if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Offset greater or equal than shape not supported");
|
||||
|
||||
int64_t targetDim1 = selfRank - 2;
|
||||
int64_t targetDim2 = selfRank - 1;
|
||||
|
||||
Value selfTransposed = self;
|
||||
SmallVector<int64_t> transposedInputShape = selfShape;
|
||||
RankedTensorType transposedInputType = selfType;
|
||||
|
||||
// If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
|
||||
// so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
|
||||
// we can consistently create the diagonal mask tensor.
|
||||
if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
|
||||
SmallVector<int32_t> transposedDims;
|
||||
transposedInputShape.clear();
|
||||
|
||||
for (int64_t i = 0; i < selfRank; ++i) {
|
||||
if (i == dim1 || i == dim2)
|
||||
continue;
|
||||
transposedDims.push_back(i);
|
||||
}
|
||||
transposedDims.push_back(dim1);
|
||||
transposedDims.push_back(dim2);
|
||||
|
||||
auto transposedDimsConst = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op,
|
||||
/*vec=*/transposedDims,
|
||||
/*shape=*/{static_cast<int32_t>(selfRank)});
|
||||
|
||||
for (auto &dim : transposedDims)
|
||||
transposedInputShape.push_back(selfShape[dim]);
|
||||
|
||||
transposedInputType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
|
||||
|
||||
selfTransposed = rewriter.create<tosa::TransposeOp>(
|
||||
op->getLoc(), transposedInputType, self, transposedDimsConst.value());
|
||||
}
|
||||
|
||||
// Define shape for mask tensor based on rank
|
||||
SmallVector<int64_t> maskShape;
|
||||
for (auto i = 0; i < selfRank - 2; i++)
|
||||
maskShape.push_back(1);
|
||||
maskShape.push_back(h);
|
||||
maskShape.push_back(w);
|
||||
|
||||
Value diagonalMask =
|
||||
TypeSwitch<Type, Value>(resultElemTy)
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
|
||||
offset);
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 1:
|
||||
return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
|
||||
offset);
|
||||
case 32:
|
||||
return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
|
||||
offset);
|
||||
case 64:
|
||||
return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
|
||||
offset);
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
|
||||
Value diagonalTensor = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
|
||||
/*shift=*/0);
|
||||
|
||||
auto resultShape = makeShapeTorchCompatible(resultType.getShape());
|
||||
auto targetReduceDim = resultShape[resultType.getRank() - 1];
|
||||
|
||||
// If transposedInputShape[targetDim1] (or h) is greater than the innermost
|
||||
// dim of the result, we won't get the correct shape when we reduce sum along
|
||||
// the innermost dim to get the result. Therefore, we have to slice the
|
||||
// transposed tensor so that transposedInputShape[targetDim1] ==
|
||||
// targetReduceDim.
|
||||
if (h > targetReduceDim) {
|
||||
transposedInputShape[targetDim1] = targetReduceDim;
|
||||
transposedInputType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
|
||||
SmallVector<int64_t> startSlice(selfRank, 0);
|
||||
SmallVector<int64_t> sizeSlice =
|
||||
llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
|
||||
if (offset < 0)
|
||||
startSlice[targetDim1] = std::abs(offset);
|
||||
diagonalTensor = rewriter.create<tosa::SliceOp>(
|
||||
op->getLoc(), transposedInputType, diagonalTensor,
|
||||
rewriter.getDenseI64ArrayAttr(startSlice),
|
||||
rewriter.getDenseI64ArrayAttr(sizeSlice));
|
||||
}
|
||||
|
||||
// Apply Reduce Sum to get the result
|
||||
auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
|
||||
auto reduceDimAttr =
|
||||
DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
|
||||
auto result =
|
||||
mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
|
||||
reduceDimAttr, /*keep_dims=*/false);
|
||||
|
||||
rewriter.replaceOp(op, result.value());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -5986,11 +6444,13 @@ public:
|
|||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
|
||||
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
|
||||
#undef INSERT_FILL_SCALAR_PATTERN
|
||||
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillTensorOp);
|
||||
#undef INSERT_FILL_PATTERN
|
||||
|
||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
|
@ -6060,6 +6520,10 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -23,6 +23,15 @@ namespace tosa {
|
|||
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// This function is a helper for `convertTorchIndexToTfIndices`.
|
||||
//
|
||||
// We convert PyTorch index to TensorFlow-style indices so that we can use
|
||||
// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather
|
||||
// and Scatter operators to TOSA using TensorFlow-style indices.
|
||||
// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow
|
||||
// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want
|
||||
// to gather/scatter elements, while in TensorFlow, the indices point directly
|
||||
// to positions that you want to gather/scatter elements.
|
||||
std::optional<Value>
|
||||
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||
SmallVector<int64_t> indicesOneDimShape, int32_t dim,
|
||||
|
@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
|||
unsigned indexRank = indexShape.size();
|
||||
SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
|
||||
SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
|
||||
int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid)
|
||||
|
||||
// Create torch.meshgrid inputs
|
||||
// Example: indexShape=[1,4,2]
|
||||
// dim0: indicesMetaElement = torch.arange(0, 1) = [0]
|
||||
// dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
|
||||
// dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
|
||||
for (int i = 0; i < indexShape[dim]; i++) {
|
||||
for (int i = 0; i < indexShape[dim]; i++)
|
||||
indicesMetaElement.push_back(i);
|
||||
}
|
||||
|
||||
// Compute total number of meta element repeat times:
|
||||
// = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim
|
||||
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
|
||||
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
|
||||
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
|
||||
for (int i = 0; i < static_cast<int>(indexRank); i++) {
|
||||
if (i == dim) {
|
||||
continue;
|
||||
} else {
|
||||
indicesMetaElementRepeatTimes *= indexShape[i];
|
||||
}
|
||||
}
|
||||
int preDimMetaElementRepeatTimes = 1;
|
||||
int postDimMetaElementRepeatTimes = 1;
|
||||
|
||||
if (dim != static_cast<int>(indexShape.size()) - 1) {
|
||||
// Create one dim indices for index except for last dim
|
||||
// Create indices raw vector.
|
||||
// torch.stack(torch.meshgrid)
|
||||
// dim0: indicesVec = [0 0 0 0 0 0 0 0]
|
||||
// dim0: indicesVec = [0 0 1 1 2 2 3 3]
|
||||
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
||||
elementId++) {
|
||||
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
|
||||
indicesVec.push_back(indicesMetaElement[elementId]);
|
||||
}
|
||||
}
|
||||
} else { // Create the one dim indices for last dim of index
|
||||
// Create indices raw vector
|
||||
// dim2: indicesVec= [0 1 0 1 0 1 0 1]
|
||||
// Caution: indicesVec != [0 0 0 0 1 1 1 1]
|
||||
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
|
||||
// Compute total number of times meta element range should repeat
|
||||
// = product(indexShape[0:dim])
|
||||
// dim0: preDimMetaElementRepeatTimes = 1
|
||||
// dim1: preDimMetaElementRepeatTimes = 1
|
||||
// dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4
|
||||
for (int i = 0; i < dim; i++)
|
||||
preDimMetaElementRepeatTimes *= indexShape[i];
|
||||
|
||||
// Compute total number of times meta element repeat
|
||||
// = product(indexShape[dim+1:indexRank])
|
||||
// dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8
|
||||
// dim1: postDimMetaElementRepeatTimes = 2
|
||||
// dim2: postDimMetaElementRepeatTimes = 1
|
||||
for (int i = dim + 1; i < static_cast<int>(indexRank); i++)
|
||||
postDimMetaElementRepeatTimes *= indexShape[i];
|
||||
|
||||
// Example using dim1:
|
||||
// preDimMetaElementRepeatTimes = 1
|
||||
// postDimMetaElementRepeatTimes = 2
|
||||
// Using postDimMetaElementRepeatTimes, we get the meta element range:
|
||||
// [0 0 1 1 2 2 3 3]
|
||||
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
|
||||
// [0 0 1 1 2 2 3 3]
|
||||
//
|
||||
// Let's use a clearer example:
|
||||
// indexShape = [3, 4, 2]
|
||||
// Target dim = 1
|
||||
// => preDimMetaElementRepeatTimes = 3
|
||||
// postDimMetaElementRepeatTimes = 2
|
||||
// Using postDimMetaElementRepeatTimes, we get the meta element range:
|
||||
// [0 0 1 1 2 2]
|
||||
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
|
||||
// [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2]
|
||||
for (int i = 0; i < preDimMetaElementRepeatTimes; i++) {
|
||||
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
||||
elementId++) {
|
||||
for (int j = 0; j < postDimMetaElementRepeatTimes; j++) {
|
||||
indicesVec.push_back(indicesMetaElement[elementId]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -132,12 +132,28 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|||
Type elemTy) {
|
||||
Value initTensor =
|
||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
|
||||
Value c0 =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
|
||||
|
||||
Type fillValElemTy = elemTy;
|
||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
|
||||
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||
|
||||
Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
|
||||
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
}
|
||||
|
||||
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||
Type elemTy) {
|
||||
Value initTensor =
|
||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||
|
||||
Type fillValElemTy = elemTy;
|
||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
|
||||
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||
|
||||
Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
|
||||
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
|
||||
}
|
||||
|
||||
Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
|
||||
assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
|
||||
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
|
||||
|
|
|
@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
|
|||
}
|
||||
|
||||
LogicalResult BindSymbolicShapeOp::verify() {
|
||||
if (getShapeSymbols().empty())
|
||||
return emitOpError() << "requires non-empty shapeSymbols";
|
||||
if (getShapeSymbols().size() !=
|
||||
getShapeExpressions().getValue().getNumSymbols())
|
||||
return emitOpError()
|
||||
<< "requires equal number of shape symbol args and symbol args to "
|
||||
"the attached affine map, since they are 1:1 mapped";
|
||||
|
||||
for (auto symbol : getShapeSymbols()) {
|
||||
Operation *definingOp = symbol.getDefiningOp();
|
||||
|
|
|
@ -9200,6 +9200,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -10352,6 +10355,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||
" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
|
@ -11895,6 +11910,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int4 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %2#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
@ -14663,6 +14697,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
|
@ -15601,6 +15639,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"size must be less than or equal to {}\"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_2 = torch.constant.str \"dimension out of range of {}\"\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %5 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %15 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
|
||||
" %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %13 = torch.aten._set_item.t %12, %4, %11 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %14 = torch.aten.append.t %12, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %12 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"";
|
||||
// clang-format on
|
||||
|
|
|
@ -7298,6 +7298,85 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
|
||||
// op.
|
||||
class DecomposeAtenAdaptiveMaxPool1dOp
|
||||
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
|
||||
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
Value input = op.getSelf();
|
||||
std::optional<unsigned> maybeRank = getTensorRank(input);
|
||||
if (!maybeRank) {
|
||||
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
||||
}
|
||||
unsigned rank = *maybeRank;
|
||||
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(rank - 1));
|
||||
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
|
||||
|
||||
Value outputShape = op.getOutputSize();
|
||||
SmallVector<Value> outputShapeSizesTorchInt;
|
||||
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
||||
Value outputSize = outputShapeSizesTorchInt[0];
|
||||
|
||||
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
|
||||
int64_t outputSizeInt;
|
||||
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the output size of adaptive_max_pool1d must be a constant int");
|
||||
}
|
||||
|
||||
SmallVector<Value, 1> kernelSize;
|
||||
if (outputSizeInt == 1) {
|
||||
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
|
||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||
kernelSize.push_back(
|
||||
inputShape[rank - 1] == kUnknownSize
|
||||
? inputSize
|
||||
: rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
|
||||
} else {
|
||||
if (!isAssumingStrictSymbolicShapes(rewriter)) {
|
||||
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, cond,
|
||||
"unimplemented: only support cases where input and output size are "
|
||||
"equal for non-unit output size");
|
||||
}
|
||||
kernelSize.push_back(constantOne);
|
||||
}
|
||||
|
||||
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
||||
Value strideList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
ValueRange{constantOne});
|
||||
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
ValueRange{constantZero});
|
||||
Value dialationList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
ValueRange{constantOne});
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
|
||||
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
|
||||
paddingSizeList, dialationList,
|
||||
/*ceil_mode=*/constantFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
|
||||
|
||||
|
@ -8720,6 +8799,77 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
|
||||
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
|
||||
using OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto self = op.getSelf();
|
||||
auto target = op.getTarget();
|
||||
auto posWeight = op.getPosWeight();
|
||||
auto weight = op.getWeight();
|
||||
auto reduction = op.getReduction();
|
||||
|
||||
Value loss;
|
||||
auto one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto _one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||
|
||||
auto _target =
|
||||
rewriter.create<AtenMulScalarOp>(loc, target.getType(), target, _one);
|
||||
auto _target_1 = rewriter.create<AtenAddScalarOp>(loc, _target.getType(),
|
||||
_target, one, one);
|
||||
Value mm =
|
||||
rewriter.create<AtenMulTensorOp>(loc, self.getType(), _target_1, self);
|
||||
Value logSigm =
|
||||
rewriter.create<AtenLogSigmoidOp>(loc, self.getType(), self);
|
||||
|
||||
if (!isa<Torch::NoneType>(posWeight.getType())) {
|
||||
auto logWeight = rewriter.create<AtenAddScalarOp>(
|
||||
loc, posWeight.getType(),
|
||||
rewriter.create<AtenSubScalarOp>(loc, posWeight.getType(), posWeight,
|
||||
one, one),
|
||||
one, one);
|
||||
loss = rewriter.create<AtenSubTensorOp>(
|
||||
loc, mm.getType(), mm,
|
||||
rewriter.create<AtenMulTensorOp>(loc, logWeight.getType(), logWeight,
|
||||
logSigm),
|
||||
one);
|
||||
} else {
|
||||
loss =
|
||||
rewriter.create<AtenSubTensorOp>(loc, mm.getType(), mm, logSigm, one);
|
||||
}
|
||||
|
||||
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||
loss =
|
||||
rewriter.create<AtenMulTensorOp>(loc, loss.getType(), loss, weight);
|
||||
}
|
||||
|
||||
// apply loss reduction.
|
||||
int64_t reductionInt;
|
||||
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
|
||||
return rewriter.notifyMatchFailure(op, "no reduction type is appointed!");
|
||||
}
|
||||
|
||||
auto none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value res;
|
||||
if (reductionInt == 1) {
|
||||
res = rewriter.create<AtenMeanOp>(loc, op.getType(), loss, none);
|
||||
} else if (reductionInt == 2) {
|
||||
res = rewriter.create<AtenSumOp>(loc, op.getType(), loss, none);
|
||||
} else {
|
||||
res = loss;
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
||||
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
||||
|
@ -9801,6 +9951,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
||||
|
@ -9856,6 +10007,8 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
|
||||
|
|
|
@ -530,11 +530,139 @@ public:
|
|||
none, none, none, none);
|
||||
return success();
|
||||
}
|
||||
auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
|
||||
if (squeezeOp && resultTy.getSizes().size() == 1) {
|
||||
rewriter.replaceOp(op, squeezeOp.getSelf());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// This is a specific pattern for converting views like [?,...,?,lastDim] ->
|
||||
// [?,...,?,factor0,factor1] to unflatten, and views like
|
||||
// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
|
||||
// possible to infer that all but last shared dim match
|
||||
// TODO: move this to an actual canonicalizer for view after deleting the
|
||||
// conflicting decompositions for flatten/unflatten -> view.
|
||||
class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenViewOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenViewOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value> viewSizes;
|
||||
if (failed(getListOperands(op.getSize(), viewSizes)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "view size must be from a list construct");
|
||||
auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
|
||||
if (!selfTy || !selfTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "missing input type or sizes");
|
||||
auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
|
||||
if (!resultTy || !resultTy.hasSizes() ||
|
||||
resultTy.getSizes().size() != viewSizes.size())
|
||||
return rewriter.notifyMatchFailure(op, "missing result type or sizes");
|
||||
int64_t inRank = selfTy.getSizes().size();
|
||||
int64_t outRank = resultTy.getSizes().size();
|
||||
|
||||
SmallVector<int64_t> sizes(selfTy.getSizes());
|
||||
int64_t endMatchingDim = -1;
|
||||
// input sizes vs. provided view sizes comparison loop
|
||||
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
|
||||
int64_t providedSize;
|
||||
bool providedStatic =
|
||||
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
|
||||
// if sizes[i] is static, it must match a constant in viewSizes[i]
|
||||
if (sizes[i] != Torch::kUnknownSize) {
|
||||
if (!providedStatic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported: found static input dim, but unable to match "
|
||||
"provided view size on a constant. See position : " +
|
||||
std::to_string(i));
|
||||
if (providedSize != sizes[i]) {
|
||||
endMatchingDim = i;
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// the remaining assumes sizes[i] is dynamic
|
||||
// if provided dim is static, we can't verify it is a flatten/unflatten
|
||||
// unless -1
|
||||
if (i == outRank - 1 && providedStatic && providedSize == -1) {
|
||||
endMatchingDim = i;
|
||||
break;
|
||||
}
|
||||
if (providedStatic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unexpected static view dim corresponding to dynamic input dim "
|
||||
"at position : " +
|
||||
std::to_string(i));
|
||||
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
|
||||
// if we don't have a size int op on self, fail
|
||||
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected dynamic view dim to come from a corresponding "
|
||||
"size.int op. See position : " +
|
||||
std::to_string(i));
|
||||
int64_t dim;
|
||||
// if the dim of the size int op doesn't match, fail
|
||||
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
|
||||
dim != i)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"size int op dim cannot be matched to current dim at position : " +
|
||||
std::to_string(i));
|
||||
// passing the previous checks means viewSizes[i] = aten.size.int(self,
|
||||
// i), so continue
|
||||
}
|
||||
// if all dims match and the ranks are equal, fold
|
||||
if (endMatchingDim == -1 && inRank == outRank) {
|
||||
rewriter.replaceOp(op, op.getSelf());
|
||||
return success();
|
||||
}
|
||||
if (endMatchingDim > -1 && inRank > outRank) {
|
||||
// only support flattening last dim
|
||||
if (endMatchingDim != outRank - 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: output has more than back dim mismatching");
|
||||
// flatten
|
||||
Value start =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
|
||||
Value end =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
|
||||
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
|
||||
op, resultTy, op.getSelf(), start, end);
|
||||
return success();
|
||||
}
|
||||
if (endMatchingDim > -1 && inRank < outRank) {
|
||||
// only support unflattening last dim
|
||||
if (endMatchingDim != inRank - 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input has more than back dim mismatching");
|
||||
// unflatten
|
||||
Value dim =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
|
||||
Value primList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op.getLoc(), op.getSize().getType(),
|
||||
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
|
||||
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
|
||||
op, resultTy, op.getSelf(), dim, primList);
|
||||
return success();
|
||||
}
|
||||
// examples that might reach this:
|
||||
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
|
||||
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
|
||||
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
|
||||
", inRank=" + std::to_string(inRank) +
|
||||
", outRank=" + std::to_string(outRank));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
|
||||
public:
|
||||
|
@ -561,12 +689,18 @@ public:
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
patterns
|
||||
.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
|
||||
patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
|
||||
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
||||
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
|
||||
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
|
||||
FoldAtenWhereSelf, RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||
FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
|
||||
RemoveUnusedPattern<Torch::AtenIntBoolOp>,
|
||||
RemoveUnusedPattern<Torch::AtenEqIntOp>,
|
||||
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
|
||||
RemoveUnusedPattern<Torch::AtenFullOp>,
|
||||
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
||||
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
||||
RemoveUnusedPattern<Torch::ConstantBoolOp>,
|
||||
|
|
|
@ -90,7 +90,28 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
|||
return torch_upstream::ScalarType::Float8_e5m2fnuz;
|
||||
if (isa<Float8E4M3FNUZType>(type))
|
||||
return torch_upstream::ScalarType::Float8_e4m3fnuz;
|
||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||
std::string errorMsg = "Unhandled type in getScalarTypeForType: ";
|
||||
llvm::raw_string_ostream os(errorMsg);
|
||||
type.print(os);
|
||||
// os << "\nType ID: " << type.getTypeID();
|
||||
os << "\nType properties:";
|
||||
os << "\n Is integer: " << (type.isInteger() ? "yes" : "no");
|
||||
os << "\n Is float: "
|
||||
<< (type.isIntOrFloat() && !type.isInteger() ? "yes" : "no");
|
||||
os << "\n Is index: " << (type.isIndex() ? "yes" : "no");
|
||||
os << "\n Bit width: "
|
||||
<< (type.isIntOrFloat() ? std::to_string(type.getIntOrFloatBitWidth())
|
||||
: "N/A");
|
||||
os << "\n Is signless: " << (type.isSignlessInteger() ? "yes" : "no");
|
||||
os << "\n Is signed: " << (type.isSignedInteger() ? "yes" : "no");
|
||||
// special error message for unsigned integer
|
||||
if (type.isUnsignedInteger()) {
|
||||
os << "\n Is unsigned: yes";
|
||||
os << "\nUnsigned integer support is currently spotty. Please seeheck "
|
||||
"https://github.com/llvm/torch-mlir/issues/3720 "
|
||||
"for more details.";
|
||||
}
|
||||
llvm::report_fatal_error(llvm::StringRef(errorMsg));
|
||||
}
|
||||
Type Torch::getTypeForTorchType(
|
||||
MLIRContext *context, Type type,
|
||||
|
@ -257,7 +278,7 @@ bool Torch::isViewLikeOp(Operation *op) {
|
|||
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
||||
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
||||
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
|
||||
AtenPixelShuffleOp, AtenDiagonalOp>(op);
|
||||
AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
|
||||
}
|
||||
|
||||
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
||||
|
|
|
@ -79,6 +79,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
#### General TorchDynamo/PyTorch errors
|
||||
# torch._dynamo.exc.Unsupported: Tensor.item
|
||||
"CumsumModule_basic",
|
||||
"CumprodModule_basic",
|
||||
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
||||
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
||||
# https://github.com/pytorch/pytorch/issues/89629
|
||||
|
@ -432,6 +433,7 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"ConvolutionBackwardModule2DStrided_basic",
|
||||
"ConvolutionBackwardModule2D_basic",
|
||||
"CumsumModule_basic",
|
||||
"CumprodModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
|
@ -504,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ViewDtypeStaticModule_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
# Error: `aten.as_strided` op is not supported
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
|
@ -588,6 +591,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"AdaptiveAvgPool3dDynamic_basic",
|
||||
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
||||
"AdaptiveMaxPool1dDynamic_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"AdaptiveMaxPool1dStatic_basic",
|
||||
"AdaptiveMaxPool2dDynamicNoBatch_basic",
|
||||
"AdaptiveMaxPool2dDynamicWithIndices_basic",
|
||||
|
@ -666,6 +670,10 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ConvolutionBackwardModule2DStrided_basic",
|
||||
"ConvolutionBackwardModule2D_basic",
|
||||
"CumsumModule_basic",
|
||||
"CumprodModule_basic",
|
||||
"CumprodInputDtypeInt32Module_basic",
|
||||
"CumprodStaticModule_basic",
|
||||
"CumprodStaticNegativeDimModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DeterminantBatchedModule_F32",
|
||||
"DeterminantDynamicModule_F32",
|
||||
|
@ -808,10 +816,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"RandnLikeDtypeModule_basic",
|
||||
"RandnLikeModule_basic",
|
||||
"RandnModule_basic",
|
||||
"ReduceAllDimBool_basic",
|
||||
"ReduceAllDimEmpty_basic",
|
||||
"ReduceAllDimFloat_basic",
|
||||
"ReduceAllDimInt_basic",
|
||||
"ReduceProdDimIntFloatModule_basic",
|
||||
"ReflectionPad1dModule2dInput_Right",
|
||||
"ReflectionPad1dModule2dInput_basic",
|
||||
|
@ -829,18 +833,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ReplicationPad2dModule_top0",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
# need aten.all.dim lowering to stablehlo
|
||||
"SafeSoftmaxModule_basic",
|
||||
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionMaskModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||
"ScaledDotProductAttentionSameModule_basic",
|
||||
"ScatterReduceFloatMaxModule",
|
||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||
"ScatterReduceFloatMeanModule",
|
||||
|
@ -926,6 +919,11 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"GCDBatchedModule_I32",
|
||||
"GCDDynamicModule_I32",
|
||||
"GCDModule_I32",
|
||||
"Unfold_Module_basic",
|
||||
"Unfold_Module_Rank_4",
|
||||
"Unfold_Module_Rank_Zero_basic",
|
||||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
||||
|
@ -1059,6 +1057,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"ContiguousModule_basic",
|
||||
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
|
@ -1079,6 +1078,9 @@ STABLEHLO_PASS_SET = {
|
|||
"CumsumInputDtypeInt32Module_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"CumprodInputDtypeInt32Module_basic",
|
||||
"CumprodStaticModule_basic",
|
||||
"CumprodStaticNegativeDimModule_basic",
|
||||
"DetachModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
|
@ -1425,6 +1427,7 @@ STABLEHLO_PASS_SET = {
|
|||
"SliceSizeTwoStepModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"SliceWholeTensorModule_basic",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
|
@ -1671,6 +1674,35 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"AtenRoundFloatHalfToEvenModule_basic",
|
||||
"AtenRoundFloatModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||
"Fill_TensorFloat64WithInt64Static_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"FlipModule_basic",
|
||||
"FlipNegativeIndexModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"AtenLinalgCrossBroadcast_basic",
|
||||
"AtenLinalgCrossCustomDim_basic",
|
||||
"AtenLinalgCrossFloat_basic",
|
||||
"AtenLinalgCrossInt_basic",
|
||||
"AtenLinalgCrossNegativeDim_basic",
|
||||
"BinaryCrossEntropyWithLogitsStaticModule_basic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"DiagonalWithStaticShapeModule_basic",
|
||||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
|
@ -1814,7 +1846,6 @@ TOSA_PASS_SET = {
|
|||
"ArangeStartOutModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"ArangeStartStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"ArangeDtypeIntModule_basic",
|
||||
"ArangeFalsePinMemoryModule_basic",
|
||||
"ArangeFloatModule_basic",
|
||||
|
@ -2115,7 +2146,6 @@ TOSA_PASS_SET = {
|
|||
"NormScalarOptDimModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"NumpyTRank0Module_basic",
|
||||
"NumpyTRank1Module_basic",
|
||||
"NumpyTRank2Module_basic",
|
||||
"NumpyTRankNDynamicModule_basic",
|
||||
|
@ -2127,7 +2157,6 @@ TOSA_PASS_SET = {
|
|||
"OnesModuleInt_basic",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic",
|
||||
"Permute0RankModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
|
@ -2166,7 +2195,6 @@ TOSA_PASS_SET = {
|
|||
"ScalarTensorInt64Module_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SiluModule_basic",
|
||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
|
@ -2348,6 +2376,13 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
}
|
||||
) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
|
@ -2588,6 +2623,7 @@ ONNX_XFAIL_SET = {
|
|||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
"SliceCopy_Module_basic",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"StdCorrectionLargeInputModule_basic",
|
||||
"TupleModule_basic",
|
||||
"VarCorrectionLargeInputModule_basic",
|
||||
|
@ -2757,6 +2793,7 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||
"ElementwiseOrTensorModule_basic",
|
||||
|
@ -3071,7 +3108,6 @@ ONNX_XFAIL_SET = {
|
|||
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||
"ScatterReduceIntMinModuleIncludeSelf",
|
||||
"ScatterValueFloatModule_basic",
|
||||
"ScatterAddStaticModule_basic",
|
||||
# Failure - onnx_lowering: onnx.ScatterND
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DIntAccumulateModule_basic",
|
||||
|
@ -3107,6 +3143,10 @@ ONNX_XFAIL_SET = {
|
|||
"CopyWithDifferentDTypesModule_basic",
|
||||
"CosineSimilarityStaticBroadcastModule_basic",
|
||||
"CumsumInputDtypeInt32Module_basic",
|
||||
"CumprodModule_basic",
|
||||
"CumprodInputDtypeInt32Module_basic",
|
||||
"CumprodStaticModule_basic",
|
||||
"CumprodStaticNegativeDimModule_basic",
|
||||
"ElementwiseAcosIntModule_basic",
|
||||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAtanTensorIntModule_basic",
|
||||
|
@ -3132,6 +3172,11 @@ ONNX_XFAIL_SET = {
|
|||
"GCDBatchedModule_I32",
|
||||
"GCDDynamicModule_I32",
|
||||
"GCDModule_I32",
|
||||
"Unfold_Module_Rank_4",
|
||||
"Unfold_Module_Rank_Zero_basic",
|
||||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
"ViewDtypeStaticModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
|
@ -3170,6 +3215,18 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
|||
"AtenIntMM_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() > version.parse("2.4.0.dev"):
|
||||
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseTanModule_basic",
|
||||
}
|
||||
FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseTanModule_basic",
|
||||
}
|
||||
|
||||
|
||||
ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
|
@ -3197,17 +3254,30 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"NumpyTRank0Module_basic",
|
||||
"Permute0RankModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"AtenPolarDoubleModule_basic",
|
||||
"AtenPolarFloatModule_basic",
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
"HstackBasicIntFloatModule_basic",
|
||||
"HstackBasicIntModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
|
@ -3220,14 +3290,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||
"ElementwiseRreluEvalModule_basic",
|
||||
"ElementwiseRreluEvalStaticModule_basic",
|
||||
"ElementwiseRreluTrainModule_basic",
|
||||
"ElementwiseRreluTrainStaticModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||
"MaskedScatterStaticBasic_basic",
|
||||
"MaxUnpool3dModulePad0_basic",
|
||||
|
@ -3294,12 +3362,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenIntTensorCharDtypeModule_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
"AtenItemIntOpModule_basic",
|
||||
"AtenLinalgCrossBroadcast_basic",
|
||||
"AtenLinalgCrossCustomDim_basic",
|
||||
"AtenLinalgCrossDynamic_basic",
|
||||
"AtenLinalgCrossFloat_basic",
|
||||
"AtenLinalgCrossInt_basic",
|
||||
"AtenLinalgCrossNegativeDim_basic",
|
||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||
"AtenMatmulQMixedSigni8_basic",
|
||||
"AtenMatmulQint8MV_basic",
|
||||
|
@ -3312,8 +3374,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenMmQuint8_basic",
|
||||
"AtenRealView128Module_basic",
|
||||
"AtenRealView64Module_basic",
|
||||
"AtenRoundFloatHalfToEvenModule_basic",
|
||||
"AtenRoundFloatModule_basic",
|
||||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
|
@ -3355,6 +3415,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
"Conv2dQInt8Module_depthwise",
|
||||
"Conv2dQInt8Module_grouped",
|
||||
|
@ -3383,18 +3444,14 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"CumsumModule_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"CumprodModule_basic",
|
||||
"CumprodInputDtypeInt32Module_basic",
|
||||
"CumprodStaticModule_basic",
|
||||
"CumprodStaticNegativeDimModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DeterminantBatchedModule_F32",
|
||||
"DeterminantDynamicModule_F32",
|
||||
"DeterminantModule_F32",
|
||||
"DiagonalModule_basic",
|
||||
"DiagonalModule_nonsquare",
|
||||
"DiagonalModule_transposed",
|
||||
"DiagonalModule_with_dims",
|
||||
"DiagonalModule_with_dims_and_offset",
|
||||
"DiagonalModule_with_negative_dims",
|
||||
"DiagonalModule_with_offset",
|
||||
"DiagonalWithStaticShapeModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
"DropoutTrainModule_basic",
|
||||
|
@ -3478,20 +3535,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"EqIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"ExponentialModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"Fill_TensorFloat32WithFloat32_basic",
|
||||
"Fill_TensorFloat32WithFloat64_basic",
|
||||
"Fill_TensorFloat32WithInt64_basic",
|
||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||
"Fill_TensorFloat64WithFloat32_basic",
|
||||
"Fill_TensorFloat64WithFloat64_basic",
|
||||
"Fill_TensorFloat64WithInt64Static_basic",
|
||||
"Fill_TensorFloat64WithInt64_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"FlipModule_basic",
|
||||
"FlipNegativeIndexModule_basic",
|
||||
"FloatImplicitModule_basic",
|
||||
"FullLikeModuleInt2D_basic",
|
||||
"FullLikeModuleInt3D_basic",
|
||||
|
@ -3547,15 +3590,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImplIndexWithNoneModule_basic",
|
||||
"IndexSelectDynamicIndexSizeModule_basic",
|
||||
"IndexSelectDynamicInputSizeModule_basic",
|
||||
"IndexSelectDynamicModulebasic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
|
@ -3753,6 +3788,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
|
@ -3808,11 +3844,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ToCopyWithDTypeModule_basic",
|
||||
"TorchPrimLoopForLikeModule_basic",
|
||||
"TorchPrimLoopWhileLikeModule_basic",
|
||||
"TraceModule_basic",
|
||||
"TraceModule_empty",
|
||||
"TraceModule_nonsquare",
|
||||
"TraceSignedIntModule_basic",
|
||||
"TraceUnsignedIntModule_basic",
|
||||
"TraceUnsignedIntModule_empty",
|
||||
"TypeConversionI1ToF64Module_basic",
|
||||
"TypeConversionI1ToI32Module_basic",
|
||||
|
@ -3833,9 +3865,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"VarMeanUnbiasedModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ZeroFloat32Module_basic",
|
||||
"ZeroInt32Module_basic",
|
||||
"ZeroInt64Module_basic",
|
||||
"VisionTransformerModule_basic",
|
||||
"ZerosLikeModule_falsePinMemory",
|
||||
}
|
||||
|
||||
|
@ -3848,6 +3878,15 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"LinspaceEmptyModule_basic",
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||
"TrilIndicesAllZerosModule_basic",
|
||||
"TriuIndicesAllZerosModule_basic",
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ReduceAllDimFloatModule_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
|
@ -3877,7 +3916,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"ElementwiseFmaxModule_basic",
|
||||
"ElementwiseFminModule_basic",
|
||||
|
@ -4010,8 +4048,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"AtenPolarDoubleModule_basic",
|
||||
"AtenRealView128Module_basic",
|
||||
"AtenRealView64Module_basic",
|
||||
"AtenRoundFloatHalfToEvenModule_basic",
|
||||
"AtenRoundFloatModule_basic",
|
||||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
|
@ -4055,8 +4091,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"BucketizeTensorFloatModule_basic",
|
||||
"BucketizeTensorModule_basic",
|
||||
"BucketizeTensorOutInt32RightModule_basic",
|
||||
"BucketizeTensorStaticFloatModule_basic",
|
||||
"BucketizeTensorStaticModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
|
@ -4075,6 +4109,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dBiasNoPaddingModule_basic",
|
||||
"Conv2dModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
|
@ -4115,6 +4150,10 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"CumsumModule_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"CumprodModule_basic",
|
||||
"CumprodInputDtypeInt32Module_basic",
|
||||
"CumprodStaticModule_basic",
|
||||
"CumprodStaticNegativeDimModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DeterminantModule_F32",
|
||||
"DeterminantBatchedModule_F32",
|
||||
|
@ -4265,7 +4304,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseWhereSelfModule_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleF16_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
"EmbeddingModuleI64_basic",
|
||||
"EmptyLikeMemoryFormatModule_basic",
|
||||
|
@ -4359,12 +4397,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexSelectDynamicIndexSizeModule_basic",
|
||||
"IndexSelectDynamicInputSizeModule_basic",
|
||||
"IndexSelectDynamicModulebasic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
"IndexTensorHackedTwinModule3dInput_basic",
|
||||
|
@ -4382,10 +4414,8 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexTensorMultiInputOneDim_basic",
|
||||
"IndexTensorMultiInputThreeIndexers_basic",
|
||||
"IndexTensorMultiInput_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"IndexTensorSelectDimModule_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
|
@ -4684,7 +4714,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ScatterValueFloatModule_basic",
|
||||
"ScatterValueIntModule_basic",
|
||||
"SelectIntModule_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
|
@ -4696,6 +4725,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"SliceCopy_Module_basic",
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceModule_basic",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||
|
|
|
@ -1445,6 +1445,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b
|
|||
def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||
return self
|
||||
|
||||
|
@ -2001,6 +2004,14 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
|
|||
def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
|
||||
return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
|
||||
|
||||
def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]:
|
||||
scalar_shape: List[int] = []
|
||||
if reduction == 0:
|
||||
result_shape = upstream_shape_functions._copy(self)
|
||||
else:
|
||||
result_shape = scalar_shape
|
||||
return result_shape
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||
])
|
||||
|
@ -2937,6 +2948,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
|
|||
return torch.int64
|
||||
return self_dtype
|
||||
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32))
|
||||
def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if is_integer_dtype(self_dtype):
|
||||
return torch.int64
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
@ -4954,6 +4977,10 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
|
|||
return dtype
|
||||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(
|
||||
tensor_shapes=[(3,3)],
|
||||
|
@ -5543,7 +5570,45 @@ def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int,
|
|||
return torch.qint8
|
||||
return torch.qint32
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero.
|
||||
Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0.
|
||||
Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case.
|
||||
Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case.
|
||||
Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension.
|
||||
Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension.
|
||||
])
|
||||
def aten〇unfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]:
|
||||
ndim = len(self)
|
||||
|
||||
# Rank zero tensor
|
||||
if ndim == 0:
|
||||
assert dimension == 0, f"dimension out of range of {ndim}"
|
||||
assert size <= 1, "size must be less than or equal to 1"
|
||||
return [size]
|
||||
|
||||
dim = dimension
|
||||
if dim < 0:
|
||||
dim += ndim
|
||||
|
||||
assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}"
|
||||
|
||||
size_dim = self[dim]
|
||||
assert size <= size_dim, f"size must be less than or equal to {size_dim}"
|
||||
|
||||
num_blocks = (size_dim - size) // step + 1
|
||||
|
||||
out = upstream_shape_functions._copy(self)
|
||||
out[dim] = num_blocks
|
||||
out.append(size)
|
||||
return out
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1)
|
||||
)
|
||||
def aten〇unfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::rad2deg : (Tensor) -> (Tensor)")
|
||||
emit("aten::complex : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::real : (Tensor) -> (Tensor)")
|
||||
emit("aten::imag : (Tensor) -> (Tensor)")
|
||||
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
||||
|
@ -617,6 +618,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
|
||||
emit(
|
||||
|
@ -740,6 +744,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
||||
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
|
@ -986,6 +993,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
|
||||
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
|
||||
|
|
|
@ -42,7 +42,7 @@ def import_onnx(contents):
|
|||
# Import the ONNX model proto from the file contents:
|
||||
raw_model = onnx.load_from_string(contents)
|
||||
# since it does not affect current e2e tests, data_prop is left false here
|
||||
model_proto = onnx.shape_inference.infer_shapes(raw_model)
|
||||
model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
|
||||
|
||||
# Import the ONNX module into an MLIR module:
|
||||
context = Context()
|
||||
|
|
|
@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class CumprodModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
ones = torch.ones([1], dtype=torch.int32)
|
||||
return torch.ops.aten.cumprod(val, ones.item())
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CumprodModule())
|
||||
def CumprodModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
|
||||
class CumprodStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 7, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.cumprod(val, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CumprodStaticModule())
|
||||
def CumprodStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
|
||||
class CumprodStaticNegativeDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 7, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.cumprod(val, dim=-1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule())
|
||||
def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
|
||||
class CumprodInputDtypeInt32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 7, 4], torch.int32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.cumprod(val, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module())
|
||||
def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 7, 4).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenToDeviceModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
|
|||
module.forward(inputVec, weight, bias)
|
||||
|
||||
|
||||
class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 4, 6], torch.float32, True),
|
||||
([4, 1, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv1d(
|
||||
inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
|
||||
)
|
||||
def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
||||
inputVec = tu.rand(2, 4, 6)
|
||||
weight = torch.randn(4, 1, 3)
|
||||
module.forward(inputVec, weight)
|
||||
|
||||
|
||||
class Conv2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCreateComplexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a, b):
|
||||
return torch.complex(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule())
|
||||
def ElementwiseCreateComplexModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(4, high=10).type(torch.float32),
|
||||
tu.randint(4, high=10).type(torch.float32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulTensorComplexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1783,6 +1783,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(1, 512, 10))
|
||||
|
||||
|
||||
class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False)
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([1, 512, 7], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return self.amp1d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic())
|
||||
def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 512, 7))
|
||||
|
||||
|
||||
# AdaptiveMaxPool2d
|
||||
|
||||
|
||||
|
|
|
@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
class ReduceAllDimFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a, dim=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
|
||||
def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -2294,6 +2314,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(8, 2), tu.randint(8, high=2))
|
||||
|
||||
|
||||
class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([8, 2], torch.float32, True),
|
||||
([8, 2], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, input, target):
|
||||
return torch.ops.aten.binary_cross_entropy_with_logits(
|
||||
input, target, reduction=0
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule())
|
||||
def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 2), tu.rand(8, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
|
@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewDtypeStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([12, 1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
res = a.view(torch.int8)
|
||||
return res
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
|
||||
def ViewDtypeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(12, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReshapeAliasCollapseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1648,3 +1672,103 @@ class Rot90NegativeEvenRotationsModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
|
||||
def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 5, 1, 7, 3))
|
||||
|
||||
|
||||
class Unfold_Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(0, 2, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module())
|
||||
def Unfold_Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4))
|
||||
|
||||
|
||||
class Unfold_Module_Negative_Dim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 4, 4, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(-1, 2, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Negative_Dim())
|
||||
def Unfold_Module_Rank_4(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 4, 4))
|
||||
|
||||
|
||||
class Unfold_Module_Rank_Zero(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(0, 1, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
|
||||
def Unfold_Module_Rank_Zero_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand())
|
||||
|
||||
|
||||
class Unfold_Module_Rank_Zero_Size_Zero(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(0, 0, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
|
||||
def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand())
|
||||
|
||||
|
||||
class Unfold_Module_Dynamic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(1, 2, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
|
||||
def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 4, 4))
|
||||
|
|
|
@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceStaticComplexInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 4, 7], torch.complex64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x[0:5:1, 1:3:1, 2:4:1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
|
||||
def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 7).to(torch.complex64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
4
setup.py
4
setup.py
|
@ -223,13 +223,13 @@ INSTALL_REQUIRES = [
|
|||
EXT_MODULES = [
|
||||
CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
|
||||
]
|
||||
NAME = "torch-mlir-core"
|
||||
NAME = "torch-mlir"
|
||||
|
||||
# If building PyTorch extensions, customize.
|
||||
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
|
||||
import torch
|
||||
|
||||
NAME = "torch-mlir"
|
||||
NAME = "torch-mlir-ext"
|
||||
INSTALL_REQUIRES.extend(
|
||||
[
|
||||
f"torch=={torch.__version__}".split("+", 1)[0],
|
||||
|
|
|
@ -16,10 +16,71 @@
|
|||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK-DAG: }
|
||||
// CHECK: }
|
||||
module {
|
||||
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||
|
||||
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
|
||||
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias(
|
||||
// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>,
|
||||
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>,
|
||||
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>,
|
||||
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>)
|
||||
// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) {
|
||||
// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>):
|
||||
// CHECK-DAG: torch.aten.select.int
|
||||
// CHECK-DAG: torch.aten.linear
|
||||
// CHECK-DAG: torch.aten.sigmoid
|
||||
// CHECK-DAG: torch.aten.tanh
|
||||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK: }
|
||||
// CHECK: torch.aten.flip
|
||||
// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) {
|
||||
// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>):
|
||||
// CHECK-DAG: torch.aten.select.int
|
||||
// CHECK-DAG: torch.aten.linear
|
||||
// CHECK-DAG: torch.aten.sigmoid
|
||||
// CHECK-DAG: torch.aten.tanh
|
||||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK: }
|
||||
// CHECK: torch.aten.flip
|
||||
// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
|
||||
// CHECK: }
|
||||
|
||||
func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
|
||||
return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs(
|
||||
// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>,
|
||||
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>,
|
||||
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>)
|
||||
// CHECK: torch.aten.transpose.int
|
||||
// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
|
||||
// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>):
|
||||
// CHECK-DAG: torch.aten.select.int
|
||||
// CHECK-DAG: torch.aten.linear
|
||||
// CHECK-DAG: torch.aten.sigmoid
|
||||
// CHECK-DAG: torch.aten.tanh
|
||||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK: }
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
|
||||
// CHECK: }
|
||||
|
||||
func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>)
|
||||
return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
|
||||
}
|
||||
|
|
|
@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
|
|||
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
|
||||
// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
|
@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
|
|||
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
|
||||
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
|
||||
// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
|
||||
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
|
||||
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
|
||||
// CHECK-NEXT: torch.runtime.assert %[[GE]]
|
||||
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
|
||||
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
|
||||
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
|
||||
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
|
||||
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
|
||||
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
|
||||
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
|
||||
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
|
||||
// CHECK: return %[[EXPAND]]
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
|
||||
|
|
|
@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
|
|||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
|
||||
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "add"
|
||||
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
|
||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
|
||||
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "sum"
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
|
||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
|
|||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
|
||||
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "multiply"
|
||||
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
|
||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
|
||||
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "prod"
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
|
||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
@ -2833,6 +2835,15 @@ func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>)
|
|||
return %0 : !torch.vtensor<[1],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_shape_scalar
|
||||
func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
|
||||
// CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64>
|
||||
// CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
|
||||
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64>
|
||||
return %0: !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
|
||||
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
|
||||
// CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
%int0 = torch.constant.int 0
|
||||
%1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %1 : !torch.vtensor<[?,?],f32>
|
||||
}
|
|
@ -696,8 +696,8 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !
|
|||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
|
||||
// CHECK: }
|
||||
|
@ -890,15 +890,15 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
|
|||
|
||||
// CHECK-LABEL: @torch.aten.max.dim$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
|
||||
// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
||||
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
|
@ -1378,16 +1378,16 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
|
|||
|
||||
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
|
@ -1859,3 +1859,142 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
|
|||
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||
return %0: !torch.vtensor<[?,?],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.diagonal$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int -2
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array<i64: 0, 0, 2, 0>} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
|
||||
// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32>
|
||||
// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> {
|
||||
%dim1 = torch.constant.int 1
|
||||
%dim2 = torch.constant.int 0
|
||||
%offset = torch.constant.int -2
|
||||
%0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
|
||||
return %0 : !torch.vtensor<[5,6,2],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.index_select(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
|
||||
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
|
||||
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
|
||||
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
|
||||
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
|
||||
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
|
||||
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
|
||||
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
|
||||
// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
|
||||
return %0 : !torch.vtensor<[4,5,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fill.Scalar(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
|
||||
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
%0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.flip(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||
return %1 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.round(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
// CHECK-LABEL: func.func @scan_1d_inclusive(
|
||||
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
|
||||
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
||||
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
||||
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
|
||||
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
||||
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
||||
// CHECK: }
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
||||
func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
|
||||
|
@ -32,8 +32,10 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
|
||||
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
|
||||
// CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
|
||||
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
|
||||
|
@ -41,8 +43,6 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
|
||||
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
|
||||
// CHECK: }
|
||||
// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
|
||||
// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
|
||||
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
|
||||
func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
|
||||
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
|
||||
|
@ -62,14 +62,14 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc:
|
|||
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
||||
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
|
||||
// CHECK: }
|
||||
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
||||
func.func @scatter_update_scalar_1D(
|
||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||
|
@ -90,7 +90,8 @@ func.func @scatter_update_scalar_1D(
|
|||
// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||
// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||
|
@ -99,7 +100,6 @@ func.func @scatter_update_scalar_1D(
|
|||
// CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
|
||||
// CHECK: tm_tensor.yield %[[ADD]] : i32
|
||||
// CHECK: }
|
||||
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
|
||||
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
|
||||
func.func @scatter_add_scalar_1D(
|
||||
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
|
||||
|
|
|
@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
|
|||
|
||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||
// expected-error @+1 {{op requires non-empty shapeSymbols}}
|
||||
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
|
||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
return %arg0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verifier should not fail here since the op does not require shapeSymbols.
|
||||
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
|
||||
return %arg0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
||||
|
|
|
@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
|
|||
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
|
||||
return %slice : !torch.vtensor<[2],si32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @view_as_flatten_static
|
||||
func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
|
||||
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
|
||||
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
|
||||
%int1024 = torch.constant.int 1024
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
|
||||
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
|
||||
%2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
|
||||
return %3 : !torch.vtensor<[?,?,1024],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @view_as_unflatten_static
|
||||
func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
|
||||
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
|
||||
// CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
|
||||
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
|
||||
%int16 = torch.constant.int 16
|
||||
%int64 = torch.constant.int 64
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
|
||||
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
|
||||
%2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
|
||||
return %3 : !torch.vtensor<[?,?,16,64],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @view_as_flatten_dynamic
|
||||
func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
|
||||
%int-1 = torch.constant.int -1
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
%2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
|
||||
return %3 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @unsqueeze_squeeze_combo
|
||||
func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
|
||||
// CHECK: %int0 = torch.constant.int 0
|
||||
// CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
|
||||
// CHECK: return %0 : !torch.int
|
||||
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
%1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
%2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
|
||||
%4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
%5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
%6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
|
||||
%7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
%8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
%9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||
%10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||
%11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
|
||||
%12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
|
||||
%13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
return %14 : !torch.int
|
||||
}
|
||||
|
|
|
@ -34,5 +34,5 @@ def test_enable_ir_printing():
|
|||
)
|
||||
|
||||
|
||||
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
|
||||
# CHECK: // -----// IR Dump After Inliner (inline)
|
||||
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {
|
||||
|
|
Loading…
Reference in New Issue