mirror of https://github.com/llvm/torch-mlir

handles 2,3,4 from https://github.com/llvm/torch-mlir/issues/1963 (#1964)

parent a7449785ec
commit 953ea39cb5
@@ -56,7 +56,12 @@ struct TorchLoweringPipelineOptions
   // to check for a specific set of legal ops to stop its iteration.
   ListOption<std::string> backendLegalOps{
       *this, "backend-legal-ops",
-      llvm::cl::desc("List of ops to be considered legal for the backend.")};
+      llvm::cl::desc("List of ops to be considered legal for the backend, such "
+                     "as 'aten.foo'.")};
+
+  Option<std::string> extraLibrary{
+      *this, "extra-library",
+      llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")};
 };
 
 /// Creates a pipeline that lowers the object graph IR that is produced by
@@ -78,10 +83,12 @@ void createTorchSimplificationPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
 /// Creates a pipeline that refines shapes of tensor operations in the program.
-void createTorchShapeRefinementPipeline(OpPassManager &pm);
+void createTorchShapeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
 /// Creates a pipeline that refines dtype of tensor operations in the program.
-void createTorchDtypeRefinementPipeline(OpPassManager &pm);
+void createTorchDtypeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
 
@@ -89,7 +96,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
 
 std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
 
-std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createReduceOpVariantsPass(StringRef extraLibrary);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
 
@@ -100,14 +108,14 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
 
 std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
 
-std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createReifyShapeCalculationsPass(StringRef extraLibrary);
 
 std::unique_ptr<OperationPass<func::FuncOp>>
 createSimplifyShapeCalculationsPass();
 
-std::unique_ptr<OperationPass<ModuleOp>> createReifyDtypeCalculationsPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createReifyDtypeCalculationsPass(StringRef extraLibrary);
 
 std::unique_ptr<OperationPass<func::FuncOp>>
 createSimplifyDtypeCalculationsPass();
 
@@ -120,13 +128,16 @@ createEraseModuleInitializerPass();
 
 std::unique_ptr<OperationPass<ModuleOp>>
 createLowerToBackendContractPass(int maxIterations, bool decompose,
-                                 ArrayRef<std::string> backendLegalOps);
+                                 ArrayRef<std::string> backendLegalOps,
+                                 StringRef extraLibrary);
 
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyBackendContractNoDecompositionsPass();
 
 StringRef getAbstractInterpLibrary();
 
+static const char kTorchOpPrefix[] = R"(torch.)";
+
 } // namespace Torch
 
 /// Registers all Torch transformation passes.
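The declarations above change the pass-construction API. A minimal usage sketch of the new surface (not part of the patch; the pipeline function and the file name "extra_fns.mlir" are hypothetical):

    // Sketch only, assuming the updated Passes.h declarations above.
    #include "mlir/Pass/PassManager.h"
    #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

    void buildBackendContractLowering(mlir::OpPassManager &pm) {
      // Note: backend-legal op names now omit the "torch." prefix.
      pm.addPass(mlir::torch::Torch::createLowerToBackendContractPass(
          /*maxIterations=*/10, /*decompose=*/true,
          /*backendLegalOps=*/{"aten.addmm"},
          /*extraLibrary=*/"extra_fns.mlir"));
    }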
@@ -151,7 +151,13 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
 
 def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
   let summary = "Reduces variants of ops to a smaller set of ops.";
-  let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"")
+  }];
+  let options = [
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for verifying custom op value semantics">,
+  ];
   let description = [{
     Replaces ops with other ops to reduce the number of variants that
     need to be handled elsewhere in the code.
@@ -240,7 +246,13 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
 
 def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
   let summary = "Reify shape calculations.";
-  let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createReifyShapeCalculationsPass(/*extraLibrary=*/"")
+  }];
+  let options = [
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for splicing into the shape library">,
+  ];
   let description = [{
   }];
 }
@@ -255,7 +267,13 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func:
 
 def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
   let summary = "Reify dtype calculations.";
-  let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createReifyDtypeCalculationsPass(/*extraLibrary=*/"")
+  }];
+  let options = [
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for splicing into the dtype library">,
+  ];
   let description = [{
   }];
 }
@@ -291,7 +309,7 @@ def LowerToBackendContract
   let summary = "Perform simplifications until the backend contract is satisfied.";
   let constructor = [{
     mlir::torch::Torch::createLowerToBackendContractPass(
-      /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{})
+      /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
   }];
   let description = [{
     This pass performs the bulk of the lowering of the program's computations
@@ -335,7 +353,9 @@ def LowerToBackendContract
     Option<"decompose", "decompose", "bool", /*default=*/"true",
            "Decompose ops.">,
     ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
-               "List of ops to be considered legal for the backend.">
+               "List of ops to be considered legal for the backend, such as 'aten.foo'.">,
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for splicing into the abstract interpretation library">,
 
   ];
   // TODO: Debug why this is needed, even though the input program has func.func
@@ -3890,7 +3890,7 @@ private:
     // on `Operation *` are not allowed, since there is no way of telling if
     // that pattern will match on an op in the `legalOpsSet` or not.
     assert(opName && "All decomposition patterns must target a single op");
-    if (!legalOpsSet.contains(opName->getStringRef()))
+    if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
       patterns.add<DecomposePattern>(context);
   }
 
@@ -18,6 +18,7 @@
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/ADT/StringSet.h"
 
 #define DEBUG_TYPE "torch-lower-to-backend-contract"
 
@@ -31,7 +32,7 @@ using namespace mlir::torch::Torch;
 
 static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                        ConversionTarget &target,
-                                       ArrayRef<std::string> backendLegalOps);
+                                       llvm::StringSet<> backendLegalOps);
 
 static LogicalResult checkType(Operation *op, Type type,
                                bool actuallyEmitDiagnostics) {
@@ -246,11 +247,11 @@ static bool satisfiesBackendContract(ModuleOp module,
 // Explicitly set ops and dialects allowed and not allowed in backend contract.
 static ConversionTarget
 getBackendContractTarget(MLIRContext *context, bool decompose,
-                         ArrayRef<std::string> backendLegalOps) {
+                         llvm::StringSet<> backendLegalOpsSet) {
   ConversionTarget target(*context);
   target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
   if (decompose)
-    markDecomposedOpsAsIllegal(context, target, backendLegalOps);
+    markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet);
   return target;
 }
 
@@ -260,21 +261,27 @@ class LowerToBackendContractPass
 public:
   LowerToBackendContractPass() = default;
   LowerToBackendContractPass(int maxIterations, bool decompose,
-                             ArrayRef<std::string> backendLegalOps) {
+                             ArrayRef<std::string> backendLegalOps,
+                             StringRef extraLibrary) {
     this->maxIterations = maxIterations;
     this->decompose = decompose;
     this->backendLegalOps = backendLegalOps;
+    this->extraLibrary = extraLibrary.str();
   }
   void runOnOperation() override {
     ModuleOp module = getOperation();
    MLIRContext *context = &getContext();
+
+    backendLegalOpsSet.clear();
+    backendLegalOpsSet.insert(backendLegalOps.begin(), backendLegalOps.end());
     ConversionTarget target =
-        getBackendContractTarget(context, decompose, backendLegalOps);
+        getBackendContractTarget(context, decompose, backendLegalOpsSet);
 
     OpPassManager pm(module.getOperationName());
     TorchLoweringPipelineOptions options;
     options.decompose = decompose;
     options.backendLegalOps = backendLegalOps;
+    options.extraLibrary = extraLibrary;
     createTorchSimplificationPipeline(pm, options);
 
     int i = 0;
@@ -301,6 +308,8 @@ public:
                << " iterations of the simplification pipeline\n";
     });
   }
+private:
+  llvm::StringSet<> backendLegalOpsSet;
 };
 
 class VerifyBackendContractNoDecompositionsPass
@@ -312,7 +321,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target =
         getBackendContractTarget(context, /*decompose*/false,
-                                 /*backendLegalOps*/{});
+                                 /*backendLegalOpsSet*/{});
 
     if (!satisfiesBackendContract(getOperation(), target,
                                   /*actuallyEmitDiagnostics=*/true)) {
@@ -324,9 +333,10 @@ public:
 
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::torch::Torch::createLowerToBackendContractPass(
-    int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps) {
-  return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
-                                                      backendLegalOps);
+    int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps,
+    StringRef extraLibrary) {
+  return std::make_unique<LowerToBackendContractPass>(
+      maxIterations, decompose, backendLegalOps, extraLibrary);
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
@@ -337,9 +347,9 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
 // The backend contract guarantees that ops with decompositions available will
 // be decomposed. The only way to have an op reach the backend contract without
 // getting decomposed is by having the user explicitly specify that op in the
-// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
+// `backendLegalOpsSet` argument to the `LowerToBackendContractPass`. Therefore,
 // here we mark as illegal all ops with decompositions except for those in
-// `backendLegalOps`.
+// `backendLegalOpsSet`.
 //
 // The legality check takes place here instead of in the `DecomposeComplexOps`
 // pass for two reasons:
@@ -352,7 +362,7 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
 // decompositions explicit in this file
 static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                        ConversionTarget &target,
-                                       ArrayRef<std::string> backendLegalOps) {
+                                       llvm::StringSet<> backendLegalOpsSet) {
   target.addIllegalOp<AtenSoftmaxIntOp>();
   target.addIllegalOp<Aten_SoftmaxOp>();
   target.addIllegalOp<Aten_LogSoftmaxOp>();
@@ -463,7 +473,13 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenVarMeanOp>();
   target.addIllegalOp<AtenNewEmptyStridedOp>();
   target.addIllegalOp<AtenBucketizeTensorOp>();
-  for (std::string opName : backendLegalOps) {
-    target.addLegalOp(OperationName(opName, context));
+  for (auto &opName : backendLegalOpsSet) {
+    target.addLegalOp(
+        OperationName(kTorchOpPrefix + opName.first().str(), context));
   }
+  target.addDynamicallyLegalOp<OperatorOp>(
+      [backendLegalOpsSet](OperatorOp opOp) {
+        auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+        return backendLegalOpsSet.contains(opName);
+      });
 }
 
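A hedged illustration (not part of the patch; the op names are taken from elsewhere in this change) of how the new StringSet-based legality in markDecomposedOpsAsIllegal behaves:

    // Sketch only, mirroring the logic above.
    llvm::StringSet<> backendLegalOpsSet;
    backendLegalOpsSet.insert("aten.addmm");     // a registered torch op
    backendLegalOpsSet.insert("goofy.identity"); // a custom op name
    // For registered ops, kTorchOpPrefix is re-added, so the OperationName
    // "torch.aten.addmm" is marked legal on the ConversionTarget.
    // For custom ops, a torch.operator op whose "name" attribute equals
    // "goofy.identity" is dynamically legal because the set contains it.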
@@ -25,7 +25,7 @@ void mlir::torch::registerTorchPasses() {
       "torch-simplification-pipeline",
       "Pipeline simplifying computations in the program.",
       mlir::torch::Torch::createTorchSimplificationPipeline);
-  mlir::PassPipelineRegistration<>(
+  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
       "torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
       mlir::torch::Torch::createTorchShapeRefinementPipeline);
 }
@@ -66,7 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   // Perform the bulk of lowering to the backend contract.
   // See the pass documentation for more information.
   pm.addPass(createLowerToBackendContractPass(
-      options.maxIterations, options.decompose, options.backendLegalOps));
+      options.maxIterations, options.decompose, options.backendLegalOps,
+      options.extraLibrary));
 }
 
 // A simplification pipeline to establish the invariants of the backend
@@ -108,7 +109,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
   // Reduce variants of ops to a smaller set of primitives.
-  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
+  pm.addNestedPass<func::FuncOp>(
+      createReduceOpVariantsPass(options.extraLibrary));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // Remove dead global slots.
   pm.addPass(createSymbolDCEPass());
@@ -121,8 +123,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   // This should be run before RefineTypes (which primarily does dtype
   // inference), because Torch type promotion rules actually depend on the shape
   // of the operand.
-  createTorchShapeRefinementPipeline(pm);
-  createTorchDtypeRefinementPipeline(pm);
+  createTorchShapeRefinementPipeline(pm, options);
+  createTorchDtypeRefinementPipeline(pm, options);
   // Refine types in the program, which mainly means inferring dtypes of ops.
   pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
   // Propagate to ABI return types the shape/dtype information discovered by
@@ -141,13 +143,15 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
 
 static void createRefinementPipeline(
     mlir::OpPassManager &pm,
-    llvm::function_ref<std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>()>
+    llvm::function_ref<
+        std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>(llvm::StringRef)>
        reifyCalculationsPass,
     llvm::function_ref<
         std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
-        simplifyCalculationsPass) {
+        simplifyCalculationsPass,
+    const mlir::torch::Torch::TorchLoweringPipelineOptions &options) {
   // Reify the library functions for each op that is present in the library.
-  pm.addPass(reifyCalculationsPass());
+  pm.addPass(reifyCalculationsPass(options.extraLibrary));
 
   // Inline the library functions to enable analysis and transformation.
   // TODO: Only inline library functions (this will currently inline
@@ -168,12 +172,14 @@ static void createRefinementPipeline(
       mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
 }
 
-void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
+void mlir::torch::Torch::createTorchShapeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
   createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
-                           Torch::createSimplifyShapeCalculationsPass);
+                           Torch::createSimplifyShapeCalculationsPass, options);
 }
 
-void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) {
+void mlir::torch::Torch::createTorchDtypeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
   createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
-                           Torch::createSimplifyDtypeCalculationsPass);
+                           Torch::createSimplifyDtypeCalculationsPass, options);
 }
@@ -12,6 +12,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "ReifyAbstractInterpCalculationsUtils.h"
 #include "llvm/ADT/StringExtras.h"
 
 using namespace mlir;
@@ -52,17 +53,39 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
   }
 }
 
+static bool
+operatorOpHasValueSemantics(OperatorOp opOp,
+                            std::optional<SymbolTable> extraLibrary) {
+  if (!extraLibrary.has_value())
+    return false;
+  auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+  std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
+                                 LibraryFunctionKind::HasValueSemantics) +
+                             Twine(opName))
+                                .str();
+  auto libFunc = extraLibrary->lookup<func::FuncOp>(libFuncName);
+  return bool(libFunc);
+}
+
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
 class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
 public:
-  ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
+                                            const std::optional<SymbolTable>& extraLibrary)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
+    this->extraLibrary = extraLibrary;
+  }
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>())
+    if (isa<OperatorOp>(op)) {
+      if (!operatorOpHasValueSemantics(cast<OperatorOp>(op), extraLibrary)) {
+        return rewriter.notifyMatchFailure(op, "does not have value semantics");
+      }
+    } else if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
       return rewriter.notifyMatchFailure(op, "does not have value semantics");
+    }
 
     rewriter.startRootUpdate(op);
     // Convert all operands.
@@ -160,6 +183,8 @@ public:
     rewriter.finalizeRootUpdate(op);
     return success();
   }
+private:
+  std::optional<SymbolTable> extraLibrary;
 };
 } // namespace
 
@@ -241,11 +266,30 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
 }
 
 namespace {
-class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
+struct ReduceOpVariantsPass
+    : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
+  ReduceOpVariantsPass() = default;
+  ReduceOpVariantsPass(StringRef extraLibrary) {
+    this->extraLibrary = extraLibrary.str();
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
+    OwningOpRef<ModuleOp> extraLibraryModule =
+        ModuleOp::create(UnknownLoc::get(context));
+    std::optional<SymbolTable> extraLibraryModuleSymTable = std::nullopt;
+    if (!extraLibrary.empty()) {
+      if (failed(loadExtraLibrary(extraLibrary, extraLibraryModule))) {
+        emitError(getOperation()->getLoc(),
+                  "Failed to load extra-library file at " + extraLibrary);
+        return signalPassFailure();
+      }
+
+      extraLibraryModuleSymTable =
+          SymbolTable(extraLibraryModule->getOperation());
+    }
+    patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(
+        context, extraLibraryModuleSymTable);
     patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
     patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
     patterns.add<ReduceNonValueSemanticOps>(context);
@@ -253,8 +297,12 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
-    target.markUnknownOpDynamicallyLegal([](Operation *op) {
-      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
+    target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
+                                             Operation *op) {
+      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
+          (isa<OperatorOp>(op) &&
+           operatorOpHasValueSemantics(cast<OperatorOp>(op),
+                                       extraLibraryModuleSymTable))) {
        auto hasValueSemantics = [](Type t) {
          // TODO: Make this an allowlist based on a closed torch dialect
          // type system.
@@ -281,6 +329,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
 } // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::torch::Torch::createReduceOpVariantsPass() {
-  return std::make_unique<ReduceOpVariantsPass>();
+mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
+  return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
 }
@@ -8,18 +8,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "ReifyAbstractInterpCalculationsUtils.h"
+#include "mlir/Parser/Parser.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
-static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
+std::string
+mlir::torch::Torch::getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
   if (libFuncKind == LibraryFunctionKind::ShapeFunction)
     return "__torch_mlir_shape_fn.";
   else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
     return "__torch_mlir_dtype_fn.";
+  else if (libFuncKind == LibraryFunctionKind::HasValueSemantics)
+    return "__torch_mlir_has_value_semantics_fn.";
   llvm_unreachable(
       "`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
 }
@@ -73,6 +80,8 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
   // looking them up in the library.
   if (name.startswith("valsem."))
     name = name.drop_front(strlen("valsem."));
+  if (isa<OperatorOp>(op))
+    name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
   std::string libFuncName =
       (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
   auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
@@ -288,3 +297,39 @@ FailureOr<Value> Torch::adjustFunctionArg(
   // Pass the operand as-is.
   return operand;
 }
+
+LogicalResult
+mlir::torch::Torch::loadExtraLibrary(const std::string &filename,
+                                     OwningOpRef<ModuleOp> &moduleToAppendTo) {
+  auto ctx = moduleToAppendTo->getContext();
+  assert(ctx && "Module should be fully initialized.");
+
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code ec = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+    return failure();
+  }
+
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  OwningOpRef<ModuleOp> module_ =
+      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, ctx);
+  if (!module_) {
+    llvm::errs() << "Error can't load file " << filename << "\n";
+    return failure();
+  }
+
+  assert((moduleToAppendTo->getBodyRegion().empty() ||
+          moduleToAppendTo->getBodyRegion().hasOneBlock()) &&
+         "Module should have at most one block.");
+  if (moduleToAppendTo->getBodyRegion().empty()) {
+    moduleToAppendTo = std::move(module_);
+  } else {
+    Block *block = moduleToAppendTo->getBody(0);
+    block->getOperations().splice(block->end(),
+                                  module_->getBody(0)->getOperations());
+  }
+
+  return success();
+}
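A minimal calling sketch for loadExtraLibrary (not part of the patch; the file path is hypothetical), following the convention documented in the companion header below: the destination module must be created with an associated context.

    // Sketch only, assuming loadExtraLibrary as implemented above.
    mlir::MLIRContext context;
    mlir::OwningOpRef<mlir::ModuleOp> module =
        mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
    if (mlir::failed(mlir::torch::Torch::loadExtraLibrary(
            "/path/to/extra_library.mlir", module)))
      llvm::errs() << "failed to load the extra library\n";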
@@ -22,7 +22,12 @@ namespace mlir {
 namespace torch {
 namespace Torch {
 
-enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition };
+enum class LibraryFunctionKind {
+  ShapeFunction,
+  DtypeFunction,
+  Decomposition,
+  HasValueSemantics
+};
 
 // Searches the function library for an abstract interpretation function for
 // `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
@@ -60,6 +65,16 @@ FailureOr<Value> adjustFunctionArg(
     OpBuilder &b, Location loc, Value operand, Type desiredType,
     function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
         [](OpBuilder &, Location, Value operand, Type) { return operand; });
 
+std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind);
+
+// Parse MLIR module at `filename` into a ModuleOp that will then
+// be appended to an existing, fully hydrated, ModuleOp; note the module
+// should have been instantiated with an associated context like so:
+// `OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));`
+LogicalResult loadExtraLibrary(const std::string &filename,
+                               OwningOpRef<ModuleOp> &moduleToAppendTo);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
@@ -61,13 +61,23 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
 }
 
 namespace {
-class ReifyDtypeCalculationsPass
+struct ReifyDtypeCalculationsPass
     : public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
+  ReifyDtypeCalculationsPass() = default;
+  ReifyDtypeCalculationsPass(StringRef extraLibrary) {
+    this->extraLibrary = extraLibrary.str();
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp module = getOperation();
     OwningOpRef<ModuleOp> library =
         parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
+    if (!extraLibrary.empty())
+      if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
+        emitError(module->getLoc(),
+                  "Failed to load extra-library file at " + extraLibrary);
+        return signalPassFailure();
+      }
 
     // Walk all the operations, and if we have a dtype function, wrap the op
     // in a `torch.dtype.calculate` op.
@@ -86,6 +96,6 @@ class ReifyDtypeCalculationsPass
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
-Torch::createReifyDtypeCalculationsPass() {
-  return std::make_unique<ReifyDtypeCalculationsPass>();
+Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) {
+  return std::make_unique<ReifyDtypeCalculationsPass>(extraLibrary);
 }
@@ -14,6 +14,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "llvm/Support/MemoryBuffer.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -55,8 +56,12 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
 }
 
 namespace {
-class ReifyShapeCalculationsPass
+struct ReifyShapeCalculationsPass
     : public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
+  ReifyShapeCalculationsPass() = default;
+  ReifyShapeCalculationsPass(StringRef extraLibrary) {
+    this->extraLibrary = extraLibrary.str();
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp module = getOperation();
@@ -66,6 +71,12 @@ class ReifyShapeCalculationsPass
     // O(#ops in the program) ideally.
     OwningOpRef<ModuleOp> library =
         parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
+    if (!extraLibrary.empty())
+      if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
+        emitError(module->getLoc(),
+                  "Failed to load extra-library file at " + extraLibrary);
+        return signalPassFailure();
+      }
 
     // Walk all the operations, and if we have a shape function, wrap the op
     // in a `torch.shape.calculate` op.
@@ -84,6 +95,6 @@ class ReifyShapeCalculationsPass
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::torch::Torch::createReifyShapeCalculationsPass() {
-  return std::make_unique<ReifyShapeCalculationsPass>();
+mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) {
+  return std::make_unique<ReifyShapeCalculationsPass>(extraLibrary);
 }
@@ -18,6 +18,6 @@ class AddmmModule(torch.nn.Module):
 example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
 
 print(torch_mlir.compile(AddmmModule(), example_args,
-                         output_type="torch", backend_legal_ops=["torch.aten.addmm"]))
+                         output_type="torch", backend_legal_ops=["aten.addmm"]))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.addmm
@@ -240,8 +240,8 @@ class ExampleArgs:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
-    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
+    OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
+    OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ],
     OutputType.STABLEHLO: [],
 }
 
@@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
             use_tracing: bool = False,
             ignore_traced_shapes=False,
             backend_legal_ops: Optional[Sequence[str]] = None,
+            _completely_unsupported_in_progress_extra_library: Optional[str] = None,
             verbose: bool = False):
     """Convert a PyTorch model to MLIR.
 
@@ -367,7 +368,11 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
     if output_type == OutputType.RAW:
         return mb.module
 
-    option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
+    option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
+        (" extra-library=" + _completely_unsupported_in_progress_extra_library)
+        if (_completely_unsupported_in_progress_extra_library is not None)
+        else ""
+    ) + "}"
     run_pipeline_with_repro_report(
         mb.module,
         f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" -split-input-file %s | FileCheck %s
+// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim
 func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax})' -split-input-file %s | FileCheck %s
+// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func.func @torch.aten.square
 func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -0,0 +1,72 @@
+import os
+import tempfile
+
+import torch
+import torch.utils.cpp_extension
+import torch_mlir
+from torch_mlir_e2e_test.annotations import export, annotate_args
+
+
+# RUN: %PYTHON %s | FileCheck %s
+
+
+def identity(x: torch.Tensor):
+    return x
+
+
+goofy_lib = torch.library.Library("goofy", "DEF")
+goofy_lib.define("identity(Tensor t) -> Tensor")
+goofy_lib.impl("identity", identity)
+
+
+class CustomOpExampleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        b = 2 * a
+        return torch.ops.goofy.identity(b)
+
+
+mod = CustomOpExampleModule()
+mod.eval()
+
+abstract_interp_src = """\
+func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
+  return %arg0 : !torch.list<int>
+}
+func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
+  %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
+  return %0#1 : !torch.int
+}
+func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
+"""
+
+with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
+    tmp.write(abstract_interp_src)
+
+module = torch_mlir.compile(
+    mod,
+    torch.ones(3, 4),
+    output_type="torch",
+    backend_legal_ops=["goofy.identity"],
+    _completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
+)
+
+print(module)
+
+# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
+# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+# CHECK: %{{.*}} = torch.constant.int 2
+# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK: return %1 : !torch.vtensor<[3,4],f32>
+# CHECK: }
+# CHECK: }