torch-mlir/include/npcomp/Dialect/TorchConversion/Transforms/Passes.td


Add TorchToIREE and factor out TorchConversion dialect.

This converts a basic list op (torch.prim.ListConstruct) to the IREE dialect.

```
def forward(self, x: float):
    return [x, x]
```

turns into:

```
builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> {
  %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float>
  return %0 : !torch.list<!torch.float>
}
```

which turns into:

```
builtin.func @forward(%arg0: f64) -> !iree.list<f64> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %0 = iree.list.create %c2 : !iree.list<f64>
  iree.list.set %0[%c0], %arg0 : !iree.list<f64>, f64
  iree.list.set %0[%c1], %arg0 : !iree.list<f64>, f64
  return %0 : !iree.list<f64>
}
```

As part of doing this, I realized that it was time to formalize the IR form that we reach right before running TorchTo{Linalg,Std,...}. We now call it the "Torch backend contract".

We then lower the "Torch backend contract" to the "npcomp backend contract", which involves the new TorchConversion (`torch_c`) dialect, which holds ops that need to operate on both the npcomp backend types (e.g. builtin tensors, i1, IREE list, etc.) and the `!torch` types.

This made more sense, as I realized that if I didn't factor out `torch_c` then the Torch dialect would have a dependency on IREE dialect (we previously didn't notice this was an issue because we only depended on `builtin` types), which seemed wrong to me.

Recommended review order:

- TorchToIREE.cpp / `TorchToIREE/basic.mlir`
- Look at the new structure of createTorchScriptToNpcompBackendPipeline. It now lives in TorchConversion/Transforms/Passes.cpp and cleanly calls into `Torch::createTorchScriptToTorchBackendPipeline` for the frontend lowering to the Torch backend contract.
- Mechanical change extracting `torch_c.{to,from}_{i1,i64,f64,builtin_tensor,iree_list}` into a new TorchConversion dialect, and a few passes specific to the lowering from the Torch backend contract to the npcomp backend contract.
- Minor fixes to TorchToLinalg.cpp to use unconverted operands (now that we convert lists as part of operand materialization, we need to use the original operands). Also added test for AtenMaxPool2dOp and fixed m_TorchConstantIntList.
- TmpDeleteDeadIREELists pass. Temporary pass for deleting dead IREE lists that are created as part of operand materialization for conv/max pool/avg pool ops in TorchToLinalg.
2021-08-12 05:40:08 +08:00
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_TORCHCONVERSION_PASSES
#define NPCOMP_TORCHCONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def VerifyInvariantsBeforeBackendLowering
: Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
let summary = "Verify invariants required by backend lowering";
let constructor =
"mlir::NPCOMP::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()";
let description = [{
This pass checks any invariants needed by the process of lowering the
`torch` dialect to the npcomp backend contract.

The most important invariant is that all tensors should be ranked and have
a known dtype. It is useful to catch violations early because they usually
stem from a simple bug in RefineTypes, yet can manifest as many different
kinds of obscure symptoms during lowering.
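
For example, the first two types below satisfy the invariant and the last two
violate it. This is a sketch: the value tensor spellings, with `*` for an
unknown rank and `unk` for an unknown dtype, are assumed from the `torch`
dialect's printed syntax.

```mlir
// Ranked with a known dtype (individual sizes may still be dynamic).
!torch.vtensor<[2,3],f32>
!torch.vtensor<[?,?],f32>
// Unknown rank, or unknown dtype: rejected before backend lowering.
!torch.vtensor<*,f32>
!torch.vtensor<[2,3],unk>
```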

TODO: This pass should probably be phrased as checking the
"torch backend contract" and moved to that dialect once we have a more
substantial definition of what that layer is from an
"allowlist" perspective.
}];
}

def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
let summary = "Convert functions to operate on builtin tensors";
let constructor = "mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass()";
let description = [{
Partial type conversion pass analogous in scope to the upstream
`func-bufferize` pass. See details there.
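
For example, a function's signature changes as sketched below. This is
illustrative only: function bodies are elided, and the `!torch.vtensor` to
builtin `tensor` mapping shown is an assumption about the printed form.

```mlir
// Before: the signature uses `!torch` value tensor types.
builtin.func @forward(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { ... }
// After: the signature uses builtin tensor types; inside the body, `torch_c`
// ops such as `torch_c.from_builtin_tensor` bridge back to `!torch` types.
builtin.func @forward(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { ... }
```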
}];
}

def FinalizingBackendTypeConversion
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
let summary = "Finalizes a partial conversion to builtin tensors";
let constructor =
"mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass()";
let description = [{
Analogous in scope to the upstream `finalizing-bufferize` pass.
See details there.
}];
}

def TmpDeleteDeadIREELists
: Pass<"torch-tmp-delete-dead-lists", "FuncOp"> {
let summary = "Delete dead !iree.list ops";
let constructor =
"mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass()";
let description = [{
Runs a few patterns to delete dead `!iree.list` ops until IREE can support
running them. Currently, these ops get materialized as part of conversions
for ops like AtenConv2dOp that have list operands, even though they are dead
(for those ops, we pattern match a specific case of static constant lists).

Left in place, these dead ops would break execution of those tests, because
the IREE side of these ops still doesn't work (nor is IREE able to delete
them itself).
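
For illustration, here is a sketch of the kind of dead list this pass targets,
using the `iree.list` op syntax. It assumes that a list counts as dead when its
only uses are `iree.list.set` ops, so that the `iree.list.create` can be erased
together with those writes. `%x` is a hypothetical `i64` value.

```mlir
// %x is a hypothetical i64 value defined elsewhere.
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%list = iree.list.create %c2 : !iree.list<i64>
iree.list.set %list[%c0], %x : !iree.list<i64>, i64
iree.list.set %list[%c1], %x : !iree.list<i64>, i64
// %list is never read, so the create and set ops above are dead.
```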

TODO: Add support to IREE to run these ops E2E.
TODO: Remove this pass once IREE can run them E2E.
}];
}

#endif // NPCOMP_TORCHCONVERSION_PASSES