mirror of https://github.com/llvm/torch-mlir
Add recognition/folder/lowering for aten::__is__, aten::ne.int, and aten::dim
Interestingly, TorchScript has its own op (`torch::jit::Operator`) registry separate from the dispatcher (it is a superset of the dispatcher). This is where the "prim" ops and some "aten" ops (that should probably be renamed to "prim") live. In particular, `aten::__is__` is in that latter category of "aten but really prim". This registry is also the source of truth for what the TorchScript interpreter calls into when it executes. The bulk of the "not part of the dispatcher" ops live inpull/213/head09feb5f579/torch/csrc/jit/runtime/register_prim_ops.cpp (L82)
And the registry itself lives in:09feb5f579/torch/csrc/jit/runtime/operator.cpp (L196)
This fold further reduces the IR of ResNet by folding away some more not-taken branches. These not-taken branches in ResNet require first-class handling of the list type which we don't yet have on any backend.
parent
7eb36b4ae7
commit
55c3cc6624
|
@ -23,14 +23,23 @@ namespace {
|
|||
static const char kGetRegisteredOpsDocstring[] =
|
||||
R"(Gets a data structure of all registered ops.
|
||||
|
||||
The returned data reflects the metadata available in the c10 dispatcher at
|
||||
the time of this call. It is meant for various code generation tools.
|
||||
The returned data reflects the metadata available in the Torch JIT's
|
||||
registry at the time of this call. It includes both the operators available
|
||||
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
|
||||
uses to implement auxiliary operations that in the non-TorchScript case
|
||||
are performed by Python itself.
|
||||
|
||||
This information is meant for various code generation tools.
|
||||
|
||||
Returns:
|
||||
A list of records, one for each op. Each record is a dict of the following:
|
||||
A list of records, one for each `torch::jit::Operator`. Known to the
|
||||
Torch JIT operator registry. Each record is a dict of the following:
|
||||
"name": tuple -> (qualified_name, overload)
|
||||
"is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
|
||||
or is a JIT-only op.
|
||||
"is_vararg": bool -> Whether the op accepts variable arguments
|
||||
"is_varret": bool -> Whether the op produces variable returns
|
||||
"is_mutable": bool -> Whether the op potentially mutates any operand
|
||||
"arguments" and "returns": List[Dict] -> Having keys:
|
||||
"type": str -> PyTorch type name as in op signatures
|
||||
"pytype": str -> PyType style type annotation
|
||||
|
@ -39,92 +48,77 @@ Returns:
|
|||
"alias_info": Dict -> Alias info with keys "before" and "after"
|
||||
)";
|
||||
|
||||
class LambdaOpRegistrationListener : public c10::OpRegistrationListener {
|
||||
public:
|
||||
using CallbackTy = std::function<void(const c10::OperatorHandle &)>;
|
||||
LambdaOpRegistrationListener(CallbackTy callback)
|
||||
: callback(std::move(callback)) {}
|
||||
void onOperatorRegistered(const c10::OperatorHandle &op) override {
|
||||
callback(op);
|
||||
}
|
||||
void onOperatorDeregistered(const c10::OperatorHandle &op) override {}
|
||||
|
||||
private:
|
||||
CallbackTy callback;
|
||||
};
|
||||
|
||||
py::list GetRegisteredOps() {
|
||||
py::list results;
|
||||
c10::Dispatcher &dispatcher = c10::Dispatcher::singleton();
|
||||
auto listener = std::make_unique<LambdaOpRegistrationListener>(
|
||||
[&](const c10::OperatorHandle &op) -> void {
|
||||
if (!op.hasSchema()) {
|
||||
// Legacy?
|
||||
return;
|
||||
}
|
||||
|
||||
py::dict record;
|
||||
{
|
||||
py::tuple name(2);
|
||||
name[0] = op.operator_name().name;
|
||||
name[1] = op.operator_name().overload_name;
|
||||
record["name"] = std::move(name);
|
||||
}
|
||||
// Walk the JIT operator registry to find all the ops that we might need
|
||||
// for introspection / ODS generation.
|
||||
// This registry contains a superset of the ops available to the dispatcher,
|
||||
// since the JIT has its own dispatch mechanism that it uses to implement
|
||||
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
||||
// as `aten::__is__`.
|
||||
for (const std::shared_ptr<torch::jit::Operator> &op :
|
||||
torch::jit::getAllOperators()) {
|
||||
const c10::FunctionSchema &schema = op->schema();
|
||||
|
||||
auto &schema = op.schema();
|
||||
record["is_vararg"] = schema.is_vararg();
|
||||
record["is_varret"] = schema.is_varret();
|
||||
record["is_mutable"] = schema.is_mutable();
|
||||
py::dict record;
|
||||
{
|
||||
py::tuple name(2);
|
||||
name[0] = schema.name();
|
||||
name[1] = schema.overload_name();
|
||||
record["name"] = std::move(name);
|
||||
}
|
||||
|
||||
py::list arguments;
|
||||
py::list returns;
|
||||
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
||||
py::dict argRecord;
|
||||
argRecord["name"] = arg.name();
|
||||
argRecord["type"] = arg.type()->str();
|
||||
argRecord["pytype"] = arg.type()->annotation_str();
|
||||
if (arg.N())
|
||||
argRecord["N"] = *arg.N();
|
||||
// TODO: If the default value becomes useful, switch on it and return
|
||||
// a real value, not just a string print.
|
||||
if (arg.default_value()) {
|
||||
std::stringstream sout;
|
||||
sout << *arg.default_value();
|
||||
argRecord["default_debug"] = sout.str();
|
||||
}
|
||||
if (arg.alias_info()) {
|
||||
py::dict aliasInfo;
|
||||
py::list before;
|
||||
py::list after;
|
||||
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
||||
before.append(std::string(symbol.toQualString()));
|
||||
}
|
||||
for (auto &symbol : arg.alias_info()->afterSets()) {
|
||||
after.append(std::string(symbol.toQualString()));
|
||||
}
|
||||
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
||||
aliasInfo["before"] = std::move(before);
|
||||
aliasInfo["after"] = std::move(after);
|
||||
argRecord["alias_info"] = std::move(aliasInfo);
|
||||
}
|
||||
record["is_c10_op"] = op->isC10Op();
|
||||
record["is_vararg"] = schema.is_vararg();
|
||||
record["is_varret"] = schema.is_varret();
|
||||
record["is_mutable"] = schema.is_mutable();
|
||||
|
||||
container.append(std::move(argRecord));
|
||||
};
|
||||
for (auto &argument : schema.arguments()) {
|
||||
addArgument(arguments, argument);
|
||||
py::list arguments;
|
||||
py::list returns;
|
||||
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
||||
py::dict argRecord;
|
||||
argRecord["name"] = arg.name();
|
||||
argRecord["type"] = arg.type()->str();
|
||||
argRecord["pytype"] = arg.type()->annotation_str();
|
||||
if (arg.N())
|
||||
argRecord["N"] = *arg.N();
|
||||
// TODO: If the default value becomes useful, switch on it and return
|
||||
// a real value, not just a string print.
|
||||
if (arg.default_value()) {
|
||||
std::stringstream sout;
|
||||
sout << *arg.default_value();
|
||||
argRecord["default_debug"] = sout.str();
|
||||
}
|
||||
if (arg.alias_info()) {
|
||||
py::dict aliasInfo;
|
||||
py::list before;
|
||||
py::list after;
|
||||
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
||||
before.append(std::string(symbol.toQualString()));
|
||||
}
|
||||
for (auto &returnArg : schema.returns()) {
|
||||
addArgument(returns, returnArg);
|
||||
for (auto &symbol : arg.alias_info()->afterSets()) {
|
||||
after.append(std::string(symbol.toQualString()));
|
||||
}
|
||||
record["arguments"] = std::move(arguments);
|
||||
record["returns"] = std::move(returns);
|
||||
results.append(std::move(record));
|
||||
});
|
||||
// Note: addRegistrationListener reports all currently registered ops
|
||||
// during the call and then incrementally reports newer ops until the RAII
|
||||
// return value is destroyed. Since we only want the current, surround in
|
||||
// a block so it immediately unregisters.
|
||||
{ dispatcher.addRegistrationListener(std::move(listener)); }
|
||||
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
||||
aliasInfo["before"] = std::move(before);
|
||||
aliasInfo["after"] = std::move(after);
|
||||
argRecord["alias_info"] = std::move(aliasInfo);
|
||||
}
|
||||
|
||||
container.append(std::move(argRecord));
|
||||
};
|
||||
for (auto &argument : schema.arguments()) {
|
||||
addArgument(arguments, argument);
|
||||
}
|
||||
for (auto &returnArg : schema.returns()) {
|
||||
addArgument(returns, returnArg);
|
||||
}
|
||||
record["arguments"] = std::move(arguments);
|
||||
record["returns"] = std::move(returns);
|
||||
results.append(std::move(record));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
|
|
@ -92,6 +92,12 @@ def generate_ops(g: "OpGenerator"):
|
|||
g.ordinary_unary_op(f"aten::{uname}(Tensor)",
|
||||
f"{snakecase_to_camelcase(uname)}Op", uname)
|
||||
|
||||
g.print_banner("TorchScript primitive ops")
|
||||
g.ordinary_primitive_op("aten::__is__(t1,t2)", "IsOp", "__is__",
|
||||
has_folder=True)
|
||||
g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
|
||||
g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
|
||||
|
||||
# Convolution ops. Note that these are special in PyTorch and the importer,
|
||||
# and we model them after the signatures of the convolution_overrideable
|
||||
# ops (generic for non-CPU/GPU backends) but set the names according to
|
||||
|
@ -274,6 +280,37 @@ class OpGenerator:
|
|||
)
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_primitive_op(self,
|
||||
kernel_sig: str,
|
||||
ods_name: str,
|
||||
op_name: str,
|
||||
traits: Sequence[str] = (),
|
||||
**kwargs):
|
||||
""""An ordinary op which might operate on a variety of non-tensor types."""
|
||||
opdef = self.define_op(
|
||||
kernel_sig=kernel_sig,
|
||||
ods_name=ods_name,
|
||||
op_name=op_name,
|
||||
traits=list(traits) + ["NoSideEffect"],
|
||||
**kwargs)
|
||||
opdef.transforms(
|
||||
type_transforms={
|
||||
"Tensor": "AnyTorchImmutableTensor",
|
||||
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||
"int": "AnyTorchIntType",
|
||||
"int[]": "AnyTorchIntListType",
|
||||
"bool": "AnyTorchBoolType",
|
||||
"bool[]": "AnyTorchBoolListType",
|
||||
"t1": "AnyTorchType",
|
||||
"t2": "AnyTorchType",
|
||||
},
|
||||
flag_transforms={
|
||||
"Tensor": ["kImmutableTensor"],
|
||||
"Tensor?": ["kImmutableTensor"],
|
||||
},
|
||||
)
|
||||
opdef.emit()
|
||||
|
||||
def ordinary_inplace_op(self, kernel_sig: str, ods_name: str, op_name: str,
|
||||
**kwargs):
|
||||
"""In-place ops (ending in '_').
|
||||
|
@ -452,7 +489,8 @@ class InflightOpDef:
|
|||
override_arg_types: Sequence[str] = None,
|
||||
override_return_types: Sequence[str] = None,
|
||||
drop_arg_indices: Sequence[int] = (),
|
||||
drop_return_indices: Sequence[int] = ()):
|
||||
drop_return_indices: Sequence[int] = (),
|
||||
has_folder: bool = False):
|
||||
super().__init__()
|
||||
self.g = g
|
||||
self.kernel_sig = kernel_sig
|
||||
|
@ -466,6 +504,7 @@ class InflightOpDef:
|
|||
self.override_return_types = override_return_types
|
||||
self.drop_arg_indices = drop_arg_indices
|
||||
self.drop_return_indices = drop_return_indices
|
||||
self.has_folder = has_folder
|
||||
self.reg_record = g.get_reg_record(self.kernel_sig)
|
||||
self._emitted = False
|
||||
self._traceback = traceback.extract_stack()[0:-2]
|
||||
|
@ -548,7 +587,8 @@ class InflightOpDef:
|
|||
self.reg_record,
|
||||
ods_ins=self.ods_ins,
|
||||
ods_outs=self.ods_outs,
|
||||
traits=self.traits)
|
||||
traits=self.traits,
|
||||
has_folder=self.has_folder)
|
||||
self.g.impl_emitter.emit_kernel_methods(
|
||||
self.ods_name,
|
||||
self.reg_record,
|
||||
|
@ -608,7 +648,8 @@ class OdsEmitter(EmitterBase):
|
|||
ods_ins: List[Tuple[str, str]],
|
||||
ods_outs: List[Tuple[str, str]],
|
||||
traits: Sequence[str] = (),
|
||||
summary: Optional[str] = None):
|
||||
summary: Optional[str] = None,
|
||||
has_folder: bool = False):
|
||||
# Def first-line.
|
||||
full_traits = list(traits)
|
||||
full_traits.append(
|
||||
|
@ -636,6 +677,9 @@ class OdsEmitter(EmitterBase):
|
|||
self._emit_dag_list_body(ods_outs)
|
||||
self.print(");")
|
||||
|
||||
if has_folder:
|
||||
self.print("let hasFolder = 1;")
|
||||
|
||||
# Def last-line.
|
||||
self.print("}\n")
|
||||
|
||||
|
|
|
@ -7,5 +7,5 @@ import _torch_mlir
|
|||
|
||||
# This check is just for a built-in op that is unlikely to change (and is
|
||||
# otherwise insignificant).
|
||||
# CHECK: {'name': ('aten::mul', 'Tensor'), 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]}
|
||||
# CHECK: {'name': ('aten::mul', 'Tensor'), 'is_c10_op': True, 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]}
|
||||
print('\n\n'.join([repr(r) for r in _torch_mlir.get_registered_ops()]))
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
||||
#define NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToStdPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
|
@ -20,6 +20,11 @@ def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
|
|||
let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
|
||||
}
|
||||
|
||||
def ConvertATenToStd : Pass<"convert-aten-to-std", "FuncOp"> {
|
||||
let summary = "Convert recognized ATen to Std ops";
|
||||
let constructor = "mlir::NPCOMP::createConvertATenToStdPass()";
|
||||
}
|
||||
|
||||
def ConvertATenToLinalg : Pass<"convert-aten-to-linalg", "FuncOp"> {
|
||||
let summary = "Convert recognized ATen to Linalg ops";
|
||||
let description = [{
|
||||
|
|
|
@ -23,6 +23,7 @@ include "mlir/IR/OpBase.td"
|
|||
def ATen_Dialect : Dialect {
|
||||
let name = "aten";
|
||||
let cppNamespace = "::mlir::NPCOMP::aten";
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -868,6 +868,64 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// TorchScript primitive ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Torch::KernelMetadata IsOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &IsOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::__is__";
|
||||
m.addArgTypes({"t1", "t2"});
|
||||
m.addArgConversions({KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"bool"});
|
||||
m.addReturnConversions({KVC::kNone});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata DimOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &DimOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::dim";
|
||||
m.addArgTypes({"Tensor"});
|
||||
m.addArgConversions({KVC::kImmutableTensor});
|
||||
m.addReturnTypes({"int"});
|
||||
m.addReturnConversions({KVC::kNone});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata NeIntOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::ne";
|
||||
m.addArgTypes({"int", "int"});
|
||||
m.addArgConversions({KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"bool"});
|
||||
m.addReturnConversions({KVC::kNone});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NN ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -472,6 +472,43 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods<Torc
|
|||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// TorchScript primitive ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
def aten_IsOp: aten_Op<"__is__", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::__is__";
|
||||
let arguments = (ins
|
||||
AnyTorchType:$self,
|
||||
AnyTorchType:$obj
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchBoolType
|
||||
);
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def aten_DimOp: aten_Op<"dim", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::dim";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchIntType
|
||||
);
|
||||
}
|
||||
|
||||
def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::ne";
|
||||
let arguments = (ins
|
||||
AnyTorchIntType:$a,
|
||||
AnyTorchIntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchBoolType
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NN ops
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Conversion/ATenToStd/ATenToStd.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Patterns (as this grows, it should be organized into multiple files)
|
||||
// -----------------------------------------------------------------------------
|
||||
// This is going to eventually be O(#aten ops), which is in the 100s.
|
||||
|
||||
// Note: Confusingly, ATen's "dim" means "number of dimensions" which is what
|
||||
// MLIR calls "rank".
|
||||
LogicalResult convertDimOp(aten::DimOp op, PatternRewriter &rewriter) {
|
||||
if (!op.getOperand().getType().isa<TensorType>())
|
||||
return rewriter.notifyMatchFailure(op, "must be tensor only");
|
||||
auto rank = rewriter.create<RankOp>(op->getLoc(), op.getOperand());
|
||||
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op.getType(), rank);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) {
|
||||
auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::ne,
|
||||
op->getOperand(0), op->getOperand(1));
|
||||
rewriter.replaceOpWithNewOp<Basicpy::BoolCastOp>(op, op.getType(), i1);
|
||||
return success();
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
class ConvertATenToStd : public ConvertATenToStdBase<ConvertATenToStd> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<StandardOpsDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
|
||||
}
|
||||
|
||||
FrozenRewritePatternSet getPatterns() {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add(convertDimOp);
|
||||
patterns.add(convertNeIntOp);
|
||||
return std::move(patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createConvertATenToStdPass() {
|
||||
return std::make_unique<ConvertATenToStd>();
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
add_npcomp_conversion_library(NPCOMPATenToStd
|
||||
ATenToStd.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/ATenToStd
|
||||
|
||||
DEPENDS
|
||||
NPCOMPConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRStandard
|
||||
NPCOMPATenDialect
|
||||
)
|
|
@ -1,4 +1,5 @@
|
|||
add_subdirectory(ATenToLinalg)
|
||||
add_subdirectory(ATenToStd)
|
||||
add_subdirectory(ATenToTCF)
|
||||
add_subdirectory(BasicpyToStd)
|
||||
add_subdirectory(NumpyToTCF)
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "npcomp/Conversion/Passes.h"
|
||||
|
||||
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
|
||||
#include "npcomp/Conversion/ATenToStd/ATenToStd.h"
|
||||
#include "npcomp/Conversion/ATenToTCF/Passes.h"
|
||||
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
||||
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
|
||||
|
|
|
@ -10,7 +10,9 @@
|
|||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -96,17 +98,25 @@ void ATenDialect::printType(mlir::Type type, DialectAsmPrinter &os) const {
|
|||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
Operation *ATenDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
// Bool (i1 -> !basicpy.BoolType).
|
||||
if (type.isa<Basicpy::BoolType>()) {
|
||||
auto i1Value = value.dyn_cast<IntegerAttr>();
|
||||
if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
|
||||
return builder.create<Basicpy::BoolConstantOp>(loc, type, i1Value);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ATenDialect::initialize() {
|
||||
addTypes<ATenListType>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||
>();
|
||||
getContext()->getOrLoadDialect("torch");
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc"
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
//===- ATenDialect.cpp ------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::aten;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult IsOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto lhsType = self().getType();
|
||||
auto rhsType = obj().getType();
|
||||
// If either type is a NoneType, make it be the lhsType.
|
||||
if (rhsType.isa<Basicpy::NoneType>())
|
||||
std::swap(lhsType, rhsType);
|
||||
// TODO: Implement and use subtype infra for this.
|
||||
// If neither type is a subtype of the other, then the result is false.
|
||||
if (lhsType.isa<Basicpy::NoneType>() && !rhsType.isa<Torch::OptionalType>())
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), 0);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||
|
||||
#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc"
|
|
@ -1,5 +1,6 @@
|
|||
add_npcomp_dialect_library(NPCOMPATenDialect
|
||||
ATenDialect.cpp
|
||||
ATenOps.cpp
|
||||
ATenDialectOpStats.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Transforms/Passes.h"
|
||||
#include "npcomp/Backend/Common/Passes.h"
|
||||
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
|
||||
#include "npcomp/Conversion/ATenToStd/ATenToStd.h"
|
||||
#include "npcomp/Conversion/ATenToTCF/Passes.h"
|
||||
#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
|
||||
#include "npcomp/Dialect/ATen/Transforms/Passes.h"
|
||||
|
@ -106,6 +107,9 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
// Recognize ATen kernels. This is a totally local transformation that
|
||||
// we want to run as soon as possible.
|
||||
pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
|
||||
// Convert any operations on primitive types as soon as possible. Unlike
|
||||
// tensor compute ops, we don't need to wait for dtype/shape inference.
|
||||
pm.addNestedPass<FuncOp>(createConvertATenToStdPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// OPT-ONLY: Right now we rely on this to eliminate certain branches that
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
// RUN: npcomp-opt <%s -convert-aten-to-std | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @aten.dim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!numpy.any_dtype>) -> i64 {
|
||||
// CHECK: %[[RANK_INDEX:.*]] = rank %[[ARG0]] : tensor<*x!numpy.any_dtype>
|
||||
// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK_INDEX]] : index to i64
|
||||
// CHECK: return %[[RANK_I64]] : i64
|
||||
func @aten.dim(%arg0: tensor<*x!numpy.any_dtype>) -> i64 {
|
||||
%0 = "aten.dim"(%arg0) : (tensor<*x!numpy.any_dtype>) -> i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @aten.ne.int(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i64,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
|
||||
// CHECK: %[[I1:.*]] = cmpi ne, %[[ARG0]], %[[ARG1]] : i64
|
||||
// CHECK: %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
|
||||
// CHECK: return %[[BASICPY_BOOL]] : !basicpy.BoolType
|
||||
func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
|
||||
%0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
// RUN: npcomp-opt %s -canonicalize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @aten.__is__
|
||||
// CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
|
||||
// CHECK: return %[[FALSE]] : !basicpy.BoolType
|
||||
func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
|
||||
%0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
Loading…
Reference in New Issue