Add recognition/folder/lowering for aten::__is__, aten::ne.int, and aten::dim

Interestingly, TorchScript has its own operator registry
(`torch::jit::Operator`), separate from the dispatcher (it is a superset
of the dispatcher's registry).

This is where the "prim" ops and some "aten" ops (that should probably
be renamed to "prim") live. In particular, `aten::__is__` is in that
latter category of "aten but really prim". This registry is also the
source of truth for what the TorchScript interpreter calls into when it
executes.
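
For illustration (a minimal sketch, not part of this change), scripting a
small function that branches on an Optional argument shows all three ops
appearing in the TorchScript graph:

    import torch
    from typing import Optional

    @torch.jit.script
    def f(x: torch.Tensor, pad: Optional[int] = None) -> bool:
        # `pad is None` is represented as aten::__is__, `x.dim()` as
        # aten::dim, and the integer comparison as the int overload of
        # aten::ne in the scripted graph.
        if pad is None:
            return x.dim() != 4
        return False

    print(f.graph)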

The bulk of the "not part of the dispatcher" ops live in
09feb5f579/torch/csrc/jit/runtime/register_prim_ops.cpp (L82)

And the registry itself lives in:
09feb5f579/torch/csrc/jit/runtime/operator.cpp (L196)

This fold further reduces the IR of ResNet by eliminating some more
not-taken branches. Those branches require first-class handling of the
list type, which we don't yet have on any backend.
pull/213/head
Sean Silva 2021-04-27 15:15:50 -07:00
parent 7eb36b4ae7
commit 55c3cc6624
18 changed files with 427 additions and 90 deletions


@@ -23,14 +23,23 @@ namespace {
static const char kGetRegisteredOpsDocstring[] =
R"(Gets a data structure of all registered ops.
The returned data reflects the metadata available in the c10 dispatcher at
the time of this call. It is meant for various code generation tools.
The returned data reflects the metadata available in the Torch JIT's
registry at the time of this call. It includes both the operators available
in the c10 dispatcher and an auxiliary set of operators that the Torch JIT
uses to implement operations that, in the non-TorchScript case, are
performed by Python itself.
This information is meant for various code generation tools.
Returns:
A list of records, one for each op. Each record is a dict of the following:
A list of records, one for each `torch::jit::Operator` known to the
Torch JIT operator registry. Each record is a dict of the following:
"name": tuple -> (qualified_name, overload)
"is_c10_op": bool -> Whether the op is in the c10 dispatcher registry,
or is a JIT-only op.
"is_vararg": bool -> Whether the op accepts variable arguments
"is_varret": bool -> Whether the op produces variable returns
"is_mutable": bool -> Whether the op potentially mutates any operand
"arguments" and "returns": List[Dict] -> Having keys:
"type": str -> PyTorch type name as in op signatures
"pytype": str -> PyType style type annotation
@@ -39,92 +48,77 @@ Returns:
"alias_info": Dict -> Alias info with keys "before" and "after"
)";
class LambdaOpRegistrationListener : public c10::OpRegistrationListener {
public:
using CallbackTy = std::function<void(const c10::OperatorHandle &)>;
LambdaOpRegistrationListener(CallbackTy callback)
: callback(std::move(callback)) {}
void onOperatorRegistered(const c10::OperatorHandle &op) override {
callback(op);
}
void onOperatorDeregistered(const c10::OperatorHandle &op) override {}
private:
CallbackTy callback;
};
py::list GetRegisteredOps() {
py::list results;
c10::Dispatcher &dispatcher = c10::Dispatcher::singleton();
auto listener = std::make_unique<LambdaOpRegistrationListener>(
[&](const c10::OperatorHandle &op) -> void {
if (!op.hasSchema()) {
// Legacy?
return;
}
py::dict record;
{
py::tuple name(2);
name[0] = op.operator_name().name;
name[1] = op.operator_name().overload_name;
record["name"] = std::move(name);
}
// Walk the JIT operator registry to find all the ops that we might need
// for introspection / ODS generation.
// This registry contains a superset of the ops available to the dispatcher,
// since the JIT has its own dispatch mechanism that it uses to implement
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
// as `aten::__is__`.
for (const std::shared_ptr<torch::jit::Operator> &op :
torch::jit::getAllOperators()) {
const c10::FunctionSchema &schema = op->schema();
auto &schema = op.schema();
record["is_vararg"] = schema.is_vararg();
record["is_varret"] = schema.is_varret();
record["is_mutable"] = schema.is_mutable();
py::dict record;
{
py::tuple name(2);
name[0] = schema.name();
name[1] = schema.overload_name();
record["name"] = std::move(name);
}
py::list arguments;
py::list returns;
auto addArgument = [](py::list &container, const c10::Argument &arg) {
py::dict argRecord;
argRecord["name"] = arg.name();
argRecord["type"] = arg.type()->str();
argRecord["pytype"] = arg.type()->annotation_str();
if (arg.N())
argRecord["N"] = *arg.N();
// TODO: If the default value becomes useful, switch on it and return
// a real value, not just a string print.
if (arg.default_value()) {
std::stringstream sout;
sout << *arg.default_value();
argRecord["default_debug"] = sout.str();
}
if (arg.alias_info()) {
py::dict aliasInfo;
py::list before;
py::list after;
for (auto &symbol : arg.alias_info()->beforeSets()) {
before.append(std::string(symbol.toQualString()));
}
for (auto &symbol : arg.alias_info()->afterSets()) {
after.append(std::string(symbol.toQualString()));
}
aliasInfo["is_write"] = arg.alias_info()->isWrite();
aliasInfo["before"] = std::move(before);
aliasInfo["after"] = std::move(after);
argRecord["alias_info"] = std::move(aliasInfo);
}
record["is_c10_op"] = op->isC10Op();
record["is_vararg"] = schema.is_vararg();
record["is_varret"] = schema.is_varret();
record["is_mutable"] = schema.is_mutable();
container.append(std::move(argRecord));
};
for (auto &argument : schema.arguments()) {
addArgument(arguments, argument);
py::list arguments;
py::list returns;
auto addArgument = [](py::list &container, const c10::Argument &arg) {
py::dict argRecord;
argRecord["name"] = arg.name();
argRecord["type"] = arg.type()->str();
argRecord["pytype"] = arg.type()->annotation_str();
if (arg.N())
argRecord["N"] = *arg.N();
// TODO: If the default value becomes useful, switch on it and return
// a real value, not just a string print.
if (arg.default_value()) {
std::stringstream sout;
sout << *arg.default_value();
argRecord["default_debug"] = sout.str();
}
if (arg.alias_info()) {
py::dict aliasInfo;
py::list before;
py::list after;
for (auto &symbol : arg.alias_info()->beforeSets()) {
before.append(std::string(symbol.toQualString()));
}
for (auto &returnArg : schema.returns()) {
addArgument(returns, returnArg);
for (auto &symbol : arg.alias_info()->afterSets()) {
after.append(std::string(symbol.toQualString()));
}
record["arguments"] = std::move(arguments);
record["returns"] = std::move(returns);
results.append(std::move(record));
});
// Note: addRegistrationListener reports all currently registered ops
// during the call and then incrementally reports newer ops until the RAII
// return value is destroyed. Since we only want the current, surround in
// a block so it immediately unregisters.
{ dispatcher.addRegistrationListener(std::move(listener)); }
aliasInfo["is_write"] = arg.alias_info()->isWrite();
aliasInfo["before"] = std::move(before);
aliasInfo["after"] = std::move(after);
argRecord["alias_info"] = std::move(aliasInfo);
}
container.append(std::move(argRecord));
};
for (auto &argument : schema.arguments()) {
addArgument(arguments, argument);
}
for (auto &returnArg : schema.returns()) {
addArgument(returns, returnArg);
}
record["arguments"] = std::move(arguments);
record["returns"] = std::move(returns);
results.append(std::move(record));
}
return results;
}
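
For context, here is a minimal sketch (not part of this patch) of how a
code generation tool might consume these records, using the new "is_c10_op"
flag to separate dispatcher-backed ops from JIT-only, prim-like ops; the
filtering logic is illustrative only:

    import _torch_mlir

    jit_only_ops = []
    for record in _torch_mlir.get_registered_ops():
        qualified_name, overload = record["name"]
        # is_c10_op == False means the op lives only in the Torch JIT
        # operator registry (e.g. aten::__is__), not in the c10 dispatcher.
        if not record["is_c10_op"]:
            jit_only_ops.append((qualified_name, overload))

    print(len(jit_only_ops), "JIT-only operators")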


@@ -92,6 +92,12 @@ def generate_ops(g: "OpGenerator"):
g.ordinary_unary_op(f"aten::{uname}(Tensor)",
f"{snakecase_to_camelcase(uname)}Op", uname)
g.print_banner("TorchScript primitive ops")
g.ordinary_primitive_op("aten::__is__(t1,t2)", "IsOp", "__is__",
has_folder=True)
g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim")
g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int")
# Convolution ops. Note that these are special in PyTorch and the importer,
# and we model them after the signatures of the convolution_overrideable
# ops (generic for non-CPU/GPU backends) but set the names according to
@@ -274,6 +280,37 @@ class OpGenerator:
)
opdef.emit()
def ordinary_primitive_op(self,
kernel_sig: str,
ods_name: str,
op_name: str,
traits: Sequence[str] = (),
**kwargs):
""""An ordinary op which might operate on a variety of non-tensor types."""
opdef = self.define_op(
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.transforms(
type_transforms={
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
"int": "AnyTorchIntType",
"int[]": "AnyTorchIntListType",
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
"t1": "AnyTorchType",
"t2": "AnyTorchType",
},
flag_transforms={
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
},
)
opdef.emit()
def ordinary_inplace_op(self, kernel_sig: str, ods_name: str, op_name: str,
**kwargs):
"""In-place ops (ending in '_').
@@ -452,7 +489,8 @@ class InflightOpDef:
override_arg_types: Sequence[str] = None,
override_return_types: Sequence[str] = None,
drop_arg_indices: Sequence[int] = (),
drop_return_indices: Sequence[int] = ()):
drop_return_indices: Sequence[int] = (),
has_folder: bool = False):
super().__init__()
self.g = g
self.kernel_sig = kernel_sig
@@ -466,6 +504,7 @@ class InflightOpDef:
self.override_return_types = override_return_types
self.drop_arg_indices = drop_arg_indices
self.drop_return_indices = drop_return_indices
self.has_folder = has_folder
self.reg_record = g.get_reg_record(self.kernel_sig)
self._emitted = False
self._traceback = traceback.extract_stack()[0:-2]
@@ -548,7 +587,8 @@ class InflightOpDef:
self.reg_record,
ods_ins=self.ods_ins,
ods_outs=self.ods_outs,
traits=self.traits)
traits=self.traits,
has_folder=self.has_folder)
self.g.impl_emitter.emit_kernel_methods(
self.ods_name,
self.reg_record,
@@ -608,7 +648,8 @@ class OdsEmitter(EmitterBase):
ods_ins: List[Tuple[str, str]],
ods_outs: List[Tuple[str, str]],
traits: Sequence[str] = (),
summary: Optional[str] = None):
summary: Optional[str] = None,
has_folder: bool = False):
# Def first-line.
full_traits = list(traits)
full_traits.append(
@@ -636,6 +677,9 @@
self._emit_dag_list_body(ods_outs)
self.print(");")
if has_folder:
self.print("let hasFolder = 1;")
# Def last-line.
self.print("}\n")


@@ -7,5 +7,5 @@ import _torch_mlir
# This check is just for a built-in op that is unlikely to change (and is
# otherwise insignificant).
# CHECK: {'name': ('aten::mul', 'Tensor'), 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]}
# CHECK: {'name': ('aten::mul', 'Tensor'), 'is_c10_op': True, 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]}
print('\n\n'.join([repr(r) for r in _torch_mlir.get_registered_ops()]))


@@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H
#define NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToStdPass();
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H


@@ -20,6 +20,11 @@ def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
}
def ConvertATenToStd : Pass<"convert-aten-to-std", "FuncOp"> {
let summary = "Convert recognized ATen to Std ops";
let constructor = "mlir::NPCOMP::createConvertATenToStdPass()";
}
def ConvertATenToLinalg : Pass<"convert-aten-to-linalg", "FuncOp"> {
let summary = "Convert recognized ATen to Linalg ops";
let description = [{


@@ -23,6 +23,7 @@ include "mlir/IR/OpBase.td"
def ATen_Dialect : Dialect {
let name = "aten";
let cppNamespace = "::mlir::NPCOMP::aten";
let hasConstantMaterializer = 1;
}
//===----------------------------------------------------------------------===//


@@ -868,6 +868,64 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
return metadata;
}
// -----------------------------------------------------------------------------
// TorchScript primitive ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata IsOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &IsOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::__is__";
m.addArgTypes({"t1", "t2"});
m.addArgConversions({KVC::kNone, KVC::kNone});
m.addReturnTypes({"bool"});
m.addReturnConversions({KVC::kNone});
return m;
})();
return metadata;
}
Torch::KernelMetadata DimOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &DimOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::dim";
m.addArgTypes({"Tensor"});
m.addArgConversions({KVC::kImmutableTensor});
m.addReturnTypes({"int"});
m.addReturnConversions({KVC::kNone});
return m;
})();
return metadata;
}
Torch::KernelMetadata NeIntOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::ne";
m.addArgTypes({"int", "int"});
m.addArgConversions({KVC::kNone, KVC::kNone});
m.addReturnTypes({"bool"});
m.addReturnConversions({KVC::kNone});
return m;
})();
return metadata;
}
// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------


@@ -472,6 +472,43 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods<Torc
);
}
// -----------------------------------------------------------------------------
// TorchScript primitive ops
// -----------------------------------------------------------------------------
def aten_IsOp: aten_Op<"__is__", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::__is__";
let arguments = (ins
AnyTorchType:$self,
AnyTorchType:$obj
);
let results = (outs
AnyTorchBoolType
);
let hasFolder = 1;
}
def aten_DimOp: aten_Op<"dim", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::dim";
let arguments = (ins
AnyTorchImmutableTensor:$self
);
let results = (outs
AnyTorchIntType
);
}
def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::ne";
let arguments = (ins
AnyTorchIntType:$a,
AnyTorchIntType:$b
);
let results = (outs
AnyTorchBoolType
);
}
// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------


@@ -0,0 +1,71 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/ATenToStd/ATenToStd.h"
#include "../PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
// Note: Confusingly, ATen's "dim" means "number of dimensions" which is what
// MLIR calls "rank".
LogicalResult convertDimOp(aten::DimOp op, PatternRewriter &rewriter) {
if (!op.getOperand().getType().isa<TensorType>())
return rewriter.notifyMatchFailure(op, "must be tensor only");
auto rank = rewriter.create<RankOp>(op->getLoc(), op.getOperand());
rewriter.replaceOpWithNewOp<IndexCastOp>(op, op.getType(), rank);
return success();
}
LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) {
auto i1 = rewriter.create<CmpIOp>(op->getLoc(), CmpIPredicate::ne,
op->getOperand(0), op->getOperand(1));
rewriter.replaceOpWithNewOp<Basicpy::BoolCastOp>(op, op.getType(), i1);
return success();
}
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
namespace {
class ConvertATenToStd : public ConvertATenToStdBase<ConvertATenToStd> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
}
void runOnOperation() override {
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
}
FrozenRewritePatternSet getPatterns() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add(convertDimOp);
patterns.add(convertNeIntOp);
return std::move(patterns);
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertATenToStdPass() {
return std::make_unique<ConvertATenToStd>();
}


@@ -0,0 +1,18 @@
add_npcomp_conversion_library(NPCOMPATenToStd
ATenToStd.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/ATenToStd
DEPENDS
NPCOMPConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRStandard
NPCOMPATenDialect
)


@@ -1,4 +1,5 @@
add_subdirectory(ATenToLinalg)
add_subdirectory(ATenToStd)
add_subdirectory(ATenToTCF)
add_subdirectory(BasicpyToStd)
add_subdirectory(NumpyToTCF)


@@ -9,6 +9,7 @@
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
#include "npcomp/Conversion/ATenToStd/ATenToStd.h"
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
#include "npcomp/Conversion/NumpyToTCF/Passes.h"


@@ -10,7 +10,9 @@
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
@@ -96,17 +98,25 @@ void ATenDialect::printType(mlir::Type type, DialectAsmPrinter &os) const {
} // namespace NPCOMP
} // namespace mlir
Operation *ATenDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
// Bool (i1 -> !basicpy.BoolType).
if (type.isa<Basicpy::BoolType>()) {
auto i1Value = value.dyn_cast<IntegerAttr>();
if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
return builder.create<Basicpy::BoolConstantOp>(loc, type, i1Value);
}
return nullptr;
}
void ATenDialect::initialize() {
addTypes<ATenListType>();
addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
>();
getContext()->getOrLoadDialect("torch");
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc"


@@ -0,0 +1,40 @@
//===- ATenOps.cpp ----------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::aten;
//===----------------------------------------------------------------------===//
// IsOp
//===----------------------------------------------------------------------===//
OpFoldResult IsOp::fold(ArrayRef<Attribute> operands) {
auto lhsType = self().getType();
auto rhsType = obj().getType();
// If either type is a NoneType, make it be the lhsType.
if (rhsType.isa<Basicpy::NoneType>())
std::swap(lhsType, rhsType);
// TODO: Implement and use subtype infra for this.
// If neither type is a subtype of the other, then the result is false.
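// (The i1 IntegerAttr produced below is rebuilt into IR by the dialect's
// materializeConstant hook, which emits a basicpy.bool_constant.)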
if (lhsType.isa<Basicpy::NoneType>() && !rhsType.isa<Torch::OptionalType>())
return IntegerAttr::get(IntegerType::get(getContext(), 1), 0);
return nullptr;
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc"


@@ -1,5 +1,6 @@
add_npcomp_dialect_library(NPCOMPATenDialect
ATenDialect.cpp
ATenOps.cpp
ATenDialectOpStats.cpp
ADDITIONAL_HEADER_DIRS


@@ -12,6 +12,7 @@
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
#include "npcomp/Conversion/ATenToStd/ATenToStd.h"
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h"
@@ -106,6 +107,9 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
// Recognize ATen kernels. This is a totally local transformation that
// we want to run as soon as possible.
pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
// Convert any operations on primitive types as soon as possible. Unlike
// tensor compute ops, we don't need to wait for dtype/shape inference.
pm.addNestedPass<FuncOp>(createConvertATenToStdPass());
if (options.optimize) {
// OPT-ONLY: Right now we rely on this to eliminate certain branches that


@@ -0,0 +1,22 @@
// RUN: npcomp-opt <%s -convert-aten-to-std | FileCheck %s
// CHECK-LABEL: func @aten.dim(
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!numpy.any_dtype>) -> i64 {
// CHECK: %[[RANK_INDEX:.*]] = rank %[[ARG0]] : tensor<*x!numpy.any_dtype>
// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK_INDEX]] : index to i64
// CHECK: return %[[RANK_I64]] : i64
func @aten.dim(%arg0: tensor<*x!numpy.any_dtype>) -> i64 {
%0 = "aten.dim"(%arg0) : (tensor<*x!numpy.any_dtype>) -> i64
return %0 : i64
}
// CHECK-LABEL: func @aten.ne.int(
// CHECK-SAME: %[[ARG0:.*]]: i64,
// CHECK-SAME: %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
// CHECK: %[[I1:.*]] = cmpi ne, %[[ARG0]], %[[ARG1]] : i64
// CHECK: %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
// CHECK: return %[[BASICPY_BOOL]] : !basicpy.BoolType
func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
%0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
return %0 : !basicpy.BoolType
}


@@ -0,0 +1,9 @@
// RUN: npcomp-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: func @aten.__is__
// CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
// CHECK: return %[[FALSE]] : !basicpy.BoolType
func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
%0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
return %0 : !basicpy.BoolType
}