//===- LowerToBackendContract.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "torch-lower-to-backend-contract"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

//===----------------------------------------------------------------------===//
// Checking the backend contract.
//===----------------------------------------------------------------------===//

static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                       ConversionTarget &target,
                                       llvm::StringSet<> backendLegalOps);

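// Checks whether `type` is allowed by the backend contract, reporting any
// problem on `op`. When `actuallyEmitDiagnostics` is true, a detailed error
// (with a note naming the pass that most likely should have eliminated the
// type) is emitted; otherwise the function just returns failure silently.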
static LogicalResult checkType(Operation *op, Type type,
                               bool actuallyEmitDiagnostics) {
  // Allow various scalar types that backends are expected to be able to
  // handle.
  if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
               Torch::DeviceType>())
    return success();

  // Backends are not expected to support dynamic computations on these types,
  // but they frequently appear as parameters to ops which backends
  // can statically pattern match and eliminate from the program.
  // For example, a tensor operand might be optional, and the backend
  // will pattern-match statically whether it is passed as a tensor or None.
  if (isa<Torch::NoneType, Torch::StringType>(type))
    return success();

  // We blanket prohibit non-value-semantic tensors.
  // All of our backends are currently based on value-semantic tensors, so
  // we consider it our responsibility to lower all non-value-semantic tensors
  // to value-semantic tensors.
  if (isa<NonValueTensorType>(type)) {
    if (actuallyEmitDiagnostics) {
      return op
          ->emitError("unsupported by backend contract: non-value tensor type")
          .attachNote()
          .append("this is likely due to a missing case in the "
                  "MaximizeValueSemantics pass");
    } else {
      return failure();
    }
  }

  // For value-semantic tensors, we require at least a known rank and dtype.
  // We are not aware of a situation where our backends can handle an unranked
  // tensor type or a tensor with a dynamic dtype.
  //
  // There are somewhat fundamental reasons for this. In particular, the problem
  // of unranked codegen is completely different from the problem of ranked
  // codegen (since ranked corresponds to a fixed loop nest structure). For all
  // codegen systems we are aware of, the program must be reduced to operate
  // on ranked tensors at some point in compilation, and we are not aware of
  // any backend with a general solution to this problem before it reaches
  // codegen. So we consider it our responsibility to eliminate unranked
  // tensors from the program.
  //
  // We aren't aware of any backend with any infrastructure to represent dynamic
  // dtypes, let alone transform and optimize them. Additionally, it is unlikely
  // that any backend, even if it supports dynamic dtypes in some form, will
  // have a sufficiently rich system for representing PyTorch type promotion
  // rules. So we consider it our responsibility to ensure that all dtypes are
  // statically known.
  if (auto tensorType = dyn_cast<ValueTensorType>(type)) {
    if (!tensorType.hasSizes()) {
      if (actuallyEmitDiagnostics) {
        return op
            ->emitError(
                "unsupported by backend contract: tensor with unknown rank")
            .attachNote()
            .append("this is likely due to a missing transfer function "
                    "in abstract_interp_lib_gen.py");
      } else {
        return failure();
      }
    }
    if (!tensorType.hasDtype()) {
      if (actuallyEmitDiagnostics) {
        return op
            ->emitError(
                "unsupported by backend contract: tensor with unknown dtype")
            .attachNote()
            .append("this is likely due to a missing transfer function in "
                    "abstract_interp_lib_gen.py");
      } else {
        return failure();
      }
    }
    return success();
  }

  // Optional types are also in the category of types which we don't expect
  // backends to dynamically compute with, but they can be pattern matched
  // in many cases that are practically necessary.
  if (auto optionalType = dyn_cast<OptionalType>(type)) {
    // TODO: Be stricter about tensor types.
    // See comment below for ListType.
    if (isa<ValueTensorType>(optionalType.getContainedType()))
      return success();
    return checkType(op, optionalType.getContainedType(),
                     actuallyEmitDiagnostics);
  }
  // List types are also in the category of types which we don't expect
  // backends to dynamically compute with, but they can be pattern matched
  // in many cases that are practically necessary. For example, the
  // strides of a convolution op are represented as a list.
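  // As an illustrative sketch (exact IR spelling may vary), such a list is
  // typically produced by a `torch.prim.ListConstruct` op that a backend can
  // match statically, e.g.:
  //   %strides = torch.prim.ListConstruct %int2, %int2
  //       : (!torch.int, !torch.int) -> !torch.list<int>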
  if (auto listType = dyn_cast<ListType>(type)) {
    // TODO: Be stricter about tensor types.
    // For the moment, there are cases (such as for torch.cat) where we end
    // up with `!torch.list<vtensor>` which doesn't have shape or dtype in
    // the contained type information. Somehow this slips through and works.
    // We should be stricter about this and properly infer the contained type
    // and shape.
    if (isa<ValueTensorType>(listType.getContainedType()))
      return success();
    return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
  }
  // Tuple types are also in the category of types which we don't expect
  // backends to dynamically compute with, but they can be pattern matched
  // in many cases that are practically necessary.
  if (auto tupleType = dyn_cast<Torch::TupleType>(type)) {
    for (auto containedType : tupleType.getContainedTypes()) {
      if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
        return failure();
    }
    return success();
  }

  // Unsupported type.
  if (actuallyEmitDiagnostics) {
    return op->emitError("unsupported by backend contract: type ") << type;
  } else {
    return failure();
  }
}

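// Checks that `op` is legal for the conversion `target` describing the backend
// contract, optionally emitting a diagnostic pointing at DecomposeComplexOps.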
static LogicalResult checkOpIsBackendLegal(Operation *op,
                                           const ConversionTarget &target,
                                           bool actuallyEmitDiagnostics) {
  if (target.isLegal(op))
    return success();

  if (actuallyEmitDiagnostics) {
    return op->emitError("found an op that was marked as backend illegal")
        .attachNote()
        .append("this is likely due to DecomposeComplexOps being unable to "
                "decompose this op");
  } else {
    return failure();
  }
}

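// Returns true if `module` satisfies the backend contract: no module
// initializers, no unimplemented torch.operator ops, all value types accepted
// by `checkType`, and every op legal according to `target`.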
static bool satisfiesBackendContract(ModuleOp module,
                                     const ConversionTarget &target,
                                     bool actuallyEmitDiagnostics = false) {
  // We do not permit `torch.global_slot`'s in the backend contract, since
  // support for them is not widespread, and this does not align with PyTorch's
  // more tracing-based direction.
  //
  // We just check for the GlobalSlotModuleInitializerOp since its verifier
  // ensures that the set of global slots matches those initialized by the
  // module initializer.
  auto walkResult0 = module.walk([&](Torch::GlobalSlotModuleInitializerOp op) {
    if (actuallyEmitDiagnostics) {
      // Report the error on the terminator to avoid dumping the whole
      // initializer itself, which can have pages of ops in it.
      op.getBody()
          ->getTerminator()
          ->emitError("unsupported by backend contract: module initializers")
          .attachNote()
          .append("this is likely due to InlineGlobalSlots being unable to "
                  "inline a global slot");
    }
    return WalkResult::interrupt();
  });
  if (walkResult0.wasInterrupted())
    return false;

  // Check for unimplemented operators first to give more direct diagnostics.
  walkResult0 = module.walk([&](Torch::OperatorOp op) {
    if (llvm::all_of(op.getResults(), [&op](auto res) {
          return succeeded(checkType(op.getOperation(), res.getType(),
                                     /*actuallyEmitDiagnostics=*/false));
        })) {
      return WalkResult::advance();
    }

    if (actuallyEmitDiagnostics) {
      op->emitError(
          "unsupported by backend contract: Unimplemented operator '" +
          op.getName() + "'");
    }
    return WalkResult::interrupt();
  });
  if (walkResult0.wasInterrupted())
    return false;

  // Check all the types of all Value's in the program and the legality of all
  // the ops.
  //
  // A pre-order walk gives a more intuitive "first error".
  // TODO: Should we report more than the first error?
  // How do we avoid making it too spammy?
  auto walkResult1 = module.walk<WalkOrder::PreOrder>([&](Block *block) {
    for (BlockArgument arg : block->getArguments())
      if (failed(checkType(block->getParentOp(), arg.getType(),
                           actuallyEmitDiagnostics))) {
        return WalkResult::interrupt();
      }
    for (Operation &op : *block) {
      if (failed(checkOpIsBackendLegal(&op, target, actuallyEmitDiagnostics)))
        return WalkResult::interrupt();

      for (OpResult result : op.getResults())
        if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics)))
          return WalkResult::interrupt();
    }

    return WalkResult::advance();
  });
  if (walkResult1.wasInterrupted())
    return false;
  return true;
}

// Explicitly set ops and dialects allowed and not allowed in backend contract.
static ConversionTarget
getBackendContractTarget(MLIRContext *context, bool decompose,
                         llvm::StringSet<> backendLegalOpsSet) {
  ConversionTarget target(*context);
  target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
  if (decompose)
    markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet);
  return target;
}

namespace {
class LowerToBackendContractPass
    : public LowerToBackendContractBase<LowerToBackendContractPass> {
public:
  LowerToBackendContractPass() = default;
  LowerToBackendContractPass(int maxIterations, bool decompose,
                             bool shapeDtypeRefine,
                             ArrayRef<std::string> backendLegalOps,
                             StringRef extraLibrary) {
    this->maxIterations = maxIterations;
    this->decompose = decompose;
    this->shapeDtypeRefine = shapeDtypeRefine;
    this->backendLegalOps = backendLegalOps;
    this->extraLibrary = extraLibrary.str();
  }
  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = &getContext();

    backendLegalOpsSet.clear();
    backendLegalOpsSet.insert(backendLegalOps.begin(), backendLegalOps.end());
    ConversionTarget target =
        getBackendContractTarget(context, decompose, backendLegalOpsSet);

    OpPassManager pm(module.getOperationName());
    TorchLoweringPipelineOptions options;
    options.decompose = decompose;
    options.shapeDtypeRefine = shapeDtypeRefine;
    options.backendLegalOps = backendLegalOps;
    options.extraLibrary = extraLibrary;
    createTorchSimplificationPipeline(pm, options);

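    // Fixed-point iteration: run the simplification pipeline repeatedly until
    // the module satisfies the backend contract, giving up (with diagnostics)
    // after `maxIterations` attempts.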
    int i = 0;
    do {
      if (i++ == maxIterations) {
        LLVM_DEBUG({
          llvm::dbgs() << "LowerToBackendContractPass: "
                       << "failed to satisfy backend contract after "
                       << maxIterations
                       << " iterations of the simplification pipeline\n";
        });
        // Show the diagnostics.
        (void)satisfiesBackendContract(module, target,
                                       /*actuallyEmitDiagnostics=*/true);
        return signalPassFailure();
      }

      if (failed(runPipeline(pm, module)))
        return signalPassFailure();
    } while (!satisfiesBackendContract(module, target));
    LLVM_DEBUG({
      llvm::dbgs() << "LowerToBackendContractPass: " << "succeeded after " << i
                   << " iterations of the simplification pipeline\n";
    });
  }

private:
  llvm::StringSet<> backendLegalOpsSet;
};

class VerifyBackendContractNoDecompositionsPass
    : public VerifyBackendContractNoDecompositionsBase<
          VerifyBackendContractNoDecompositionsPass> {
public:
  VerifyBackendContractNoDecompositionsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target =
        getBackendContractTarget(context, /*decompose*/ false,
                                 /*backendLegalOpsSet*/ {});

    if (!satisfiesBackendContract(getOperation(), target,
                                  /*actuallyEmitDiagnostics=*/true)) {
      return signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createLowerToBackendContractPass(
    int maxIterations, bool decompose, bool shapeDtypeRefine,
    ArrayRef<std::string> backendLegalOps, StringRef extraLibrary) {
  return std::make_unique<LowerToBackendContractPass>(
      maxIterations, decompose, shapeDtypeRefine, backendLegalOps,
      extraLibrary);
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
  return std::make_unique<VerifyBackendContractNoDecompositionsPass>();
}

// The backend contract guarantees that ops with decompositions available will
// be decomposed. The only way to have an op reach the backend contract without
// getting decomposed is by having the user explicitly specify that op in the
// `backendLegalOpsSet` argument to the `LowerToBackendContractPass`.
// Therefore, here we mark as illegal all ops with decompositions except for
// those in `backendLegalOpsSet`.
//
// The legality check takes place here instead of in the `DecomposeComplexOps`
// pass for two reasons:
// 1. Makes sure the `DecomposeComplexOps` pass always succeeds, allowing it to
//    run multiple times. This is needed for graphs where static information
//    such as dtypes and shapes takes multiple iterations to propagate through
//    the entire graph. A failing `DecomposeComplexOps` pass would cause the
//    entire `LowerToBackendContractPass` to fail.
// 2. Makes the legality requirements in the backend contract for ops with
//    decompositions explicit in this file.
static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                       ConversionTarget &target,
                                       llvm::StringSet<> backendLegalOpsSet) {
  target.addIllegalOp<AtenSoftmaxIntOp>();
  target.addIllegalOp<Aten_SoftmaxOp>();
  target.addIllegalOp<Aten_LogSoftmaxOp>();
  target.addIllegalOp<AtenLogSoftmaxIntOp>();
  target.addIllegalOp<AtenLogSigmoidOp>();
  target.addIllegalOp<AtenHardshrinkOp>();
  target.addIllegalOp<AtenSoftshrinkOp>();
  target.addIllegalOp<AtenEmptyLikeOp>();
  target.addIllegalOp<AtenOnesLikeOp>();
  target.addIllegalOp<AtenZerosLikeOp>();
  target.addIllegalOp<AtenStackOp>();
  target.addIllegalOp<AtenRollOp>();
  target.addIllegalOp<AtenRepeatOp>();
  target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
  target.addIllegalOp<AtenExpandOp>();
  target.addIllegalOp<AtenFlattenUsingIntsOp>();
  target.addIllegalOp<AtenWhereScalarOp>();
  target.addIllegalOp<AtenWhereScalarOtherOp>();
  target.addIllegalOp<AtenWhereScalarSelfOp>();
  target.addIllegalOp<AtenMaskedFillScalarOp>();
  target.addIllegalOp<AtenMaskedScatterOp>();
  target.addIllegalOp<AtenSizeOp>();
  target.addIllegalOp<AtenReshapeOp>();
  target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
  target.addIllegalOp<AtenTanhBackwardOp>();
  target.addIllegalOp<AtenEinsumOp>();
  target.addIllegalOp<AtenTraceOp>();
  target.addIllegalOp<AtenAddmmOp>();
  target.addIllegalOp<AtenMeanOp>();
  target.addIllegalOp<AtenMeanDimOp>();
  target.addIllegalOp<AtenNormScalarOptDimOp>();
  target.addIllegalOp<AtenSelectIntOp>();
  target.addIllegalOp<AtenMvOp>();
  target.addIllegalOp<AtenLinalgCrossOp>();
  target.addIllegalOp<AtenPixelShuffleOp>();
  target.addIllegalOp<AtenTOp>();
  target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
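  // Only the 2-D x 2-D and 3-D x 3-D cases of aten.matmul are required to be
  // decomposed; every other rank combination is left legal for the backend.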
  target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
    std::optional<unsigned> lhsRank = getTensorRank(op.getSelf());
    std::optional<unsigned> rhsRank = getTensorRank(op.getOther());
    if (!lhsRank || !rhsRank)
      return false;
    // Make aten.matmul legal if the following condition is satisfied.
    return (*lhsRank != 2 || *rhsRank != 2) && (*lhsRank != 3 || *rhsRank != 3);
  });
  target.addIllegalOp<AtenAddcmulOp>();
  target.addIllegalOp<AtenAddcdivOp>();
  target.addIllegalOp<AtenInstanceNormOp>();
  target.addIllegalOp<AtenLayerNormOp>();
  target.addIllegalOp<AtenNativeLayerNormOp>();
  target.addIllegalOp<AtenGroupNormOp>();
  target.addIllegalOp<AtenNativeGroupNormOp>();
  target.addIllegalOp<AtenNativeBatchNormOp>();
  target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
  target.addIllegalOp<AtenConvolutionBackwardOp>();
  target.addIllegalOp<AtenConvTbcOp>();
  target.addIllegalOp<AtenConv1dOp>();
  target.addIllegalOp<AtenConv2dOp>();
  target.addIllegalOp<AtenConv3dOp>();
  target.addIllegalOp<AtenConvTranspose2dInputOp>();
  target.addIllegalOp<AtenArangeOp>();
  target.addIllegalOp<AtenArangeStartOp>();
  target.addIllegalOp<AtenLinspaceOp>();
  target.addIllegalOp<AtenArgmaxOp>();
  target.addIllegalOp<AtenArgminOp>();
  target.addIllegalOp<AtenSquareOp>();
  target.addIllegalOp<AtenVarOp>();
  target.addIllegalOp<AtenStdOp>();
  target.addIllegalOp<Aten_UnsafeViewOp>();
  target.addIllegalOp<Aten_ReshapeAliasOp>();
  target.addIllegalOp<AtenBernoulliOp>();
  target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
  target.addIllegalOp<AtenBernoulliPOp>();
  target.addIllegalOp<AtenBernoulliTensorOp>();
  target.addIllegalOp<AtenExponentialOp>();
  target.addIllegalOp<AtenZeroOp>();
  target.addIllegalOp<AtenEyeOp>();
  target.addIllegalOp<AtenEyeMOp>();
  target.addIllegalOp<AtenNanToNumOp>();
  target.addIllegalOp<AtenIsnanOp>();
  target.addIllegalOp<AtenIsinfOp>();
  target.addIllegalOp<AtenIsneginfOp>();
  target.addIllegalOp<AtenIsposinfOp>();
  target.addIllegalOp<AtenRandLikeOp>();
  target.addIllegalOp<AtenHardsigmoidOp>();
  target.addIllegalOp<AtenRelu6Op>();
  target.addIllegalOp<AtenEluOp>();
  target.addIllegalOp<AtenFakeQuantizePerTensorAffineOp>();
  target.addIllegalOp<AtenGluOp>();
  target.addIllegalOp<AtenSeluOp>();
  target.addIllegalOp<AtenHardswishOp>();
  target.addIllegalOp<AtenSoftplusOp>();
  target.addIllegalOp<AtenSiluOp>();
  target.addIllegalOp<AtenNewZerosOp>();
  target.addIllegalOp<AtenNewOnesOp>();
  target.addIllegalOp<AtenHardtanhOp>();
  target.addIllegalOp<AtenFullOp>();
  target.addIllegalOp<AtenLinearOp>();
  target.addIllegalOp<AtenMishOp>();
  target.addIllegalOp<AtenFullLikeOp>();
  target.addIllegalOp<AtenNewFullOp>();
  target.addIllegalOp<AtenExpandAsOp>();
  target.addIllegalOp<Aten_ToCopyOp>();
  target.addIllegalOp<AtenDropoutOp>();
  target.addIllegalOp<AtenNativeDropoutOp>();
  target.addIllegalOp<AtenNewEmptyOp>();
  target.addIllegalOp<AtenIndexTensorOp>();
  target.addIllegalOp<AtenIndexPutOp>();
  target.addIllegalOp<Aten_IndexPutImplOp>();
  target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
  target.addIllegalOp<AtenPadOp>();
  target.addIllegalOp<AtenPreluOp>();
  target.addIllegalOp<AtenCeluOp>();
  target.addIllegalOp<AtenToDtypeLayoutOp>();
  target.addIllegalOp<AtenToDeviceOp>();
  target.addIllegalOp<AtenToPrimDeviceOp>();
  target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
  target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
  target.addIllegalOp<AtenClampMinOp>();
  target.addIllegalOp<AtenClampMinTensorOp>();
  target.addIllegalOp<AtenClampMaxOp>();
  target.addIllegalOp<AtenBaddbmmOp>();
  target.addIllegalOp<AtenFloorDivideOp>();
  target.addIllegalOp<AtenFloorDivideScalarOp>();
  target.addIllegalOp<AtenNumpyTOp>();
  target.addIllegalOp<AtenSelectScatterOp>();
  target.addIllegalOp<AtenVarDimOp>();
  target.addIllegalOp<AtenAmaxOp>();
  target.addIllegalOp<AtenVarCorrectionOp>();
  target.addIllegalOp<AtenStdDimOp>();
  target.addIllegalOp<AtenStdCorrectionOp>();
  target.addIllegalOp<AtenNarrowOp>();
  target.addIllegalOp<AtenNarrowTensorOp>();
  target.addIllegalOp<Aten_EmbeddingBagOp>();
  target.addIllegalOp<AtenLiftFreshCopyOp>();
  target.addIllegalOp<AtenLerpScalarOp>();
  target.addIllegalOp<AtenMseLossOp>();
  target.addIllegalOp<AtenRandintLowOp>();
  target.addIllegalOp<AtenRandintOp>();
  target.addIllegalOp<AtenVarMeanCorrectionOp>();
  target.addIllegalOp<PrimsConvertElementTypeOp>();
  target.addIllegalOp<PrimsVarOp>();
  target.addIllegalOp<PrimsSqrtOp>();
  target.addIllegalOp<AtenRandOp>();
  target.addIllegalOp<AtenRandnOp>();
  target.addIllegalOp<AtenRandnGeneratorOp>();
  target.addIllegalOp<AtenRandnLikeOp>();
  target.addIllegalOp<AtenNormalFunctionalOp>();
  target.addIllegalOp<AtenVarMeanOp>();
  target.addIllegalOp<AtenCosineSimilarityOp>();
  target.addIllegalOp<AtenTruncOp>();
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();
  target.addIllegalOp<PrimsSqueezeOp>();
  target.addIllegalOp<AtenMovedimIntOp>();
  target.addIllegalOp<AtenOneHotOp>();
  target.addIllegalOp<AtenCrossEntropyLossOp>();
  target.addIllegalOp<AtenVarMeanDimOp>();
  target.addIllegalOp<AtenTopkOp>();
  target.addIllegalOp<AtenScalarTensorOp>();
  target.addIllegalOp<AtenScatterValueOp>();
  target.addIllegalOp<AtenTypeAsOp>();
  target.addIllegalOp<AtenTileOp>();
  target.addIllegalOp<AtenReshapeAsOp>();
  target.addIllegalOp<AtenTriuOp>();
  target.addIllegalOp<AtenLinalgNormOp>();
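  // Ops the user explicitly listed as backend-legal are exempt from the
  // decomposition requirement above, and a generic torch.operator is legal
  // only if its name is in `backendLegalOpsSet`.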
  for (auto &opName : backendLegalOpsSet) {
    target.addLegalOp(
        OperationName(kTorchOpPrefix + opName.first().str(), context));
  }
  target.addDynamicallyLegalOp<OperatorOp>(
      [backendLegalOpsSet](OperatorOp opOp) {
        auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
        return backendLegalOpsSet.contains(opName);
      });
}