From 55c3cc6624f90308803495c7fea5bf36a685384c Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 27 Apr 2021 15:15:50 -0700 Subject: [PATCH] Add recognition/folder/lowering for aten::__is__, aten::ne.int, and aten::dim Interestingly, TorchScript has its own op (`torch::jit::Operator`) registry separate from the dispatcher (it is a superset of the dispatcher). This is where the "prim" ops and some "aten" ops (that should probably be renamed to "prim") live. In particular, `aten::__is__` is in that latter category of "aten but really prim". This registry is also the source of truth for what the TorchScript interpreter calls into when it executes. The bulk of the "not part of the dispatcher" ops live in https://github.com/pytorch/pytorch/blob/09feb5f579d40b550559360870907963e20c6954/torch/csrc/jit/runtime/register_prim_ops.cpp#L82 And the registry itself lives in: https://github.com/pytorch/pytorch/blob/09feb5f579d40b550559360870907963e20c6954/torch/csrc/jit/runtime/operator.cpp#L196 This fold further reduces the IR of ResNet by folding away some more not-taken branches. These not-taken branches in ResNet require first-class handling of the list type which we don't yet have on any backend. --- .../pytorch/csrc/builder/python_bindings.cpp | 156 +++++++++--------- .../codegen/torch_signature_ods_gen.py | 50 +++++- .../test/builder/get_registered_ops.py | 2 +- .../npcomp/Conversion/ATenToStd/ATenToStd.h | 21 +++ include/npcomp/Conversion/Passes.td | 5 + include/npcomp/Dialect/ATen/IR/ATenDialect.td | 1 + .../Dialect/ATen/IR/GeneratedATenOps.cpp.inc | 58 +++++++ .../Dialect/ATen/IR/GeneratedATenOps.td | 37 +++++ lib/Conversion/ATenToStd/ATenToStd.cpp | 71 ++++++++ lib/Conversion/ATenToStd/CMakeLists.txt | 18 ++ lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/Passes.cpp | 1 + lib/Dialect/ATen/IR/ATenDialect.cpp | 20 ++- lib/Dialect/ATen/IR/ATenOps.cpp | 40 +++++ lib/Dialect/ATen/IR/CMakeLists.txt | 1 + lib/Dialect/Torch/Transforms/Passes.cpp | 4 + test/Conversion/ATenToStd/basic.mlir | 22 +++ test/Dialect/ATen/canonicalize.mlir | 9 + 18 files changed, 427 insertions(+), 90 deletions(-) create mode 100644 include/npcomp/Conversion/ATenToStd/ATenToStd.h create mode 100644 lib/Conversion/ATenToStd/ATenToStd.cpp create mode 100644 lib/Conversion/ATenToStd/CMakeLists.txt create mode 100644 lib/Dialect/ATen/IR/ATenOps.cpp create mode 100644 test/Conversion/ATenToStd/basic.mlir create mode 100644 test/Dialect/ATen/canonicalize.mlir diff --git a/frontends/pytorch/csrc/builder/python_bindings.cpp b/frontends/pytorch/csrc/builder/python_bindings.cpp index 93e507958..8053b5c59 100644 --- a/frontends/pytorch/csrc/builder/python_bindings.cpp +++ b/frontends/pytorch/csrc/builder/python_bindings.cpp @@ -23,14 +23,23 @@ namespace { static const char kGetRegisteredOpsDocstring[] = R"(Gets a data structure of all registered ops. -The returned data reflects the metadata available in the c10 dispatcher at -the time of this call. It is meant for various code generation tools. +The returned data reflects the metadata available in the Torch JIT's +registry at the time of this call. It includes both the operators available +in the c10 dispatcher and an auxiliary set of operators that the Torch JIT +uses to implement auxiliary operations that in the non-TorchScript case +are performed by Python itself. + +This information is meant for various code generation tools. Returns: - A list of records, one for each op. Each record is a dict of the following: + A list of records, one for each `torch::jit::Operator`. 
Known to the + Torch JIT operator registry. Each record is a dict of the following: "name": tuple -> (qualified_name, overload) + "is_c10_op": bool -> Whether the op is in the c10 dispatcher registry, + or is a JIT-only op. "is_vararg": bool -> Whether the op accepts variable arguments "is_varret": bool -> Whether the op produces variable returns + "is_mutable": bool -> Whether the op potentially mutates any operand "arguments" and "returns": List[Dict] -> Having keys: "type": str -> PyTorch type name as in op signatures "pytype": str -> PyType style type annotation @@ -39,92 +48,77 @@ Returns: "alias_info": Dict -> Alias info with keys "before" and "after" )"; -class LambdaOpRegistrationListener : public c10::OpRegistrationListener { -public: - using CallbackTy = std::function; - LambdaOpRegistrationListener(CallbackTy callback) - : callback(std::move(callback)) {} - void onOperatorRegistered(const c10::OperatorHandle &op) override { - callback(op); - } - void onOperatorDeregistered(const c10::OperatorHandle &op) override {} - -private: - CallbackTy callback; -}; - py::list GetRegisteredOps() { py::list results; - c10::Dispatcher &dispatcher = c10::Dispatcher::singleton(); - auto listener = std::make_unique( - [&](const c10::OperatorHandle &op) -> void { - if (!op.hasSchema()) { - // Legacy? - return; - } - py::dict record; - { - py::tuple name(2); - name[0] = op.operator_name().name; - name[1] = op.operator_name().overload_name; - record["name"] = std::move(name); - } + // Walk the JIT operator registry to find all the ops that we might need + // for introspection / ODS generation. + // This registry contains a superset of the ops available to the dispatcher, + // since the JIT has its own dispatch mechanism that it uses to implement + // "prim" ops and a handful of "aten" ops that are effectively prim ops, such + // as `aten::__is__`. + for (const std::shared_ptr &op : + torch::jit::getAllOperators()) { + const c10::FunctionSchema &schema = op->schema(); - auto &schema = op.schema(); - record["is_vararg"] = schema.is_vararg(); - record["is_varret"] = schema.is_varret(); - record["is_mutable"] = schema.is_mutable(); + py::dict record; + { + py::tuple name(2); + name[0] = schema.name(); + name[1] = schema.overload_name(); + record["name"] = std::move(name); + } - py::list arguments; - py::list returns; - auto addArgument = [](py::list &container, const c10::Argument &arg) { - py::dict argRecord; - argRecord["name"] = arg.name(); - argRecord["type"] = arg.type()->str(); - argRecord["pytype"] = arg.type()->annotation_str(); - if (arg.N()) - argRecord["N"] = *arg.N(); - // TODO: If the default value becomes useful, switch on it and return - // a real value, not just a string print. 
- if (arg.default_value()) { - std::stringstream sout; - sout << *arg.default_value(); - argRecord["default_debug"] = sout.str(); - } - if (arg.alias_info()) { - py::dict aliasInfo; - py::list before; - py::list after; - for (auto &symbol : arg.alias_info()->beforeSets()) { - before.append(std::string(symbol.toQualString())); - } - for (auto &symbol : arg.alias_info()->afterSets()) { - after.append(std::string(symbol.toQualString())); - } - aliasInfo["is_write"] = arg.alias_info()->isWrite(); - aliasInfo["before"] = std::move(before); - aliasInfo["after"] = std::move(after); - argRecord["alias_info"] = std::move(aliasInfo); - } + record["is_c10_op"] = op->isC10Op(); + record["is_vararg"] = schema.is_vararg(); + record["is_varret"] = schema.is_varret(); + record["is_mutable"] = schema.is_mutable(); - container.append(std::move(argRecord)); - }; - for (auto &argument : schema.arguments()) { - addArgument(arguments, argument); + py::list arguments; + py::list returns; + auto addArgument = [](py::list &container, const c10::Argument &arg) { + py::dict argRecord; + argRecord["name"] = arg.name(); + argRecord["type"] = arg.type()->str(); + argRecord["pytype"] = arg.type()->annotation_str(); + if (arg.N()) + argRecord["N"] = *arg.N(); + // TODO: If the default value becomes useful, switch on it and return + // a real value, not just a string print. + if (arg.default_value()) { + std::stringstream sout; + sout << *arg.default_value(); + argRecord["default_debug"] = sout.str(); + } + if (arg.alias_info()) { + py::dict aliasInfo; + py::list before; + py::list after; + for (auto &symbol : arg.alias_info()->beforeSets()) { + before.append(std::string(symbol.toQualString())); } - for (auto &returnArg : schema.returns()) { - addArgument(returns, returnArg); + for (auto &symbol : arg.alias_info()->afterSets()) { + after.append(std::string(symbol.toQualString())); } - record["arguments"] = std::move(arguments); - record["returns"] = std::move(returns); - results.append(std::move(record)); - }); - // Note: addRegistrationListener reports all currently registered ops - // during the call and then incrementally reports newer ops until the RAII - // return value is destroyed. Since we only want the current, surround in - // a block so it immediately unregisters. 
- { dispatcher.addRegistrationListener(std::move(listener)); } + aliasInfo["is_write"] = arg.alias_info()->isWrite(); + aliasInfo["before"] = std::move(before); + aliasInfo["after"] = std::move(after); + argRecord["alias_info"] = std::move(aliasInfo); + } + + container.append(std::move(argRecord)); + }; + for (auto &argument : schema.arguments()) { + addArgument(arguments, argument); + } + for (auto &returnArg : schema.returns()) { + addArgument(returns, returnArg); + } + record["arguments"] = std::move(arguments); + record["returns"] = std::move(returns); + results.append(std::move(record)); + } + return results; } diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py index 88d28abfb..1211279d5 100644 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py +++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py @@ -92,6 +92,12 @@ def generate_ops(g: "OpGenerator"): g.ordinary_unary_op(f"aten::{uname}(Tensor)", f"{snakecase_to_camelcase(uname)}Op", uname) + g.print_banner("TorchScript primitive ops") + g.ordinary_primitive_op("aten::__is__(t1,t2)", "IsOp", "__is__", + has_folder=True) + g.ordinary_primitive_op("aten::dim(Tensor)", "DimOp", "dim") + g.ordinary_primitive_op("aten::ne(int,int)", "NeIntOp", "ne.int") + # Convolution ops. Note that these are special in PyTorch and the importer, # and we model them after the signatures of the convolution_overrideable # ops (generic for non-CPU/GPU backends) but set the names according to @@ -274,6 +280,37 @@ class OpGenerator: ) opdef.emit() + def ordinary_primitive_op(self, + kernel_sig: str, + ods_name: str, + op_name: str, + traits: Sequence[str] = (), + **kwargs): + """"An ordinary op which might operate on a variety of non-tensor types.""" + opdef = self.define_op( + kernel_sig=kernel_sig, + ods_name=ods_name, + op_name=op_name, + traits=list(traits) + ["NoSideEffect"], + **kwargs) + opdef.transforms( + type_transforms={ + "Tensor": "AnyTorchImmutableTensor", + "Tensor?": "AnyTorchOptionalImmutableTensor", + "int": "AnyTorchIntType", + "int[]": "AnyTorchIntListType", + "bool": "AnyTorchBoolType", + "bool[]": "AnyTorchBoolListType", + "t1": "AnyTorchType", + "t2": "AnyTorchType", + }, + flag_transforms={ + "Tensor": ["kImmutableTensor"], + "Tensor?": ["kImmutableTensor"], + }, + ) + opdef.emit() + def ordinary_inplace_op(self, kernel_sig: str, ods_name: str, op_name: str, **kwargs): """In-place ops (ending in '_'). 
@@ -452,7 +489,8 @@ class InflightOpDef: override_arg_types: Sequence[str] = None, override_return_types: Sequence[str] = None, drop_arg_indices: Sequence[int] = (), - drop_return_indices: Sequence[int] = ()): + drop_return_indices: Sequence[int] = (), + has_folder: bool = False): super().__init__() self.g = g self.kernel_sig = kernel_sig @@ -466,6 +504,7 @@ class InflightOpDef: self.override_return_types = override_return_types self.drop_arg_indices = drop_arg_indices self.drop_return_indices = drop_return_indices + self.has_folder = has_folder self.reg_record = g.get_reg_record(self.kernel_sig) self._emitted = False self._traceback = traceback.extract_stack()[0:-2] @@ -548,7 +587,8 @@ class InflightOpDef: self.reg_record, ods_ins=self.ods_ins, ods_outs=self.ods_outs, - traits=self.traits) + traits=self.traits, + has_folder=self.has_folder) self.g.impl_emitter.emit_kernel_methods( self.ods_name, self.reg_record, @@ -608,7 +648,8 @@ class OdsEmitter(EmitterBase): ods_ins: List[Tuple[str, str]], ods_outs: List[Tuple[str, str]], traits: Sequence[str] = (), - summary: Optional[str] = None): + summary: Optional[str] = None, + has_folder: bool = False): # Def first-line. full_traits = list(traits) full_traits.append( @@ -636,6 +677,9 @@ class OdsEmitter(EmitterBase): self._emit_dag_list_body(ods_outs) self.print(");") + if has_folder: + self.print("let hasFolder = 1;") + # Def last-line. self.print("}\n") diff --git a/frontends/pytorch/test/builder/get_registered_ops.py b/frontends/pytorch/test/builder/get_registered_ops.py index 6f85723cc..f33874ad2 100644 --- a/frontends/pytorch/test/builder/get_registered_ops.py +++ b/frontends/pytorch/test/builder/get_registered_ops.py @@ -7,5 +7,5 @@ import _torch_mlir # This check is just for a built-in op that is unlikely to change (and is # otherwise insignificant). -# CHECK: {'name': ('aten::mul', 'Tensor'), 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]} +# CHECK: {'name': ('aten::mul', 'Tensor'), 'is_c10_op': True, 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]} print('\n\n'.join([repr(r) for r in _torch_mlir.get_registered_ops()])) diff --git a/include/npcomp/Conversion/ATenToStd/ATenToStd.h b/include/npcomp/Conversion/ATenToStd/ATenToStd.h new file mode 100644 index 000000000..cad34b872 --- /dev/null +++ b/include/npcomp/Conversion/ATenToStd/ATenToStd.h @@ -0,0 +1,21 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H +#define NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace NPCOMP { +std::unique_ptr> createConvertATenToStdPass(); +} +} // namespace mlir + +#endif // NPCOMP_CONVERSION_ATENTOSTD_ATENTOSTD_H diff --git a/include/npcomp/Conversion/Passes.td b/include/npcomp/Conversion/Passes.td index 5c0c7dbf0..8b74f5d2e 100644 --- a/include/npcomp/Conversion/Passes.td +++ b/include/npcomp/Conversion/Passes.td @@ -20,6 +20,11 @@ def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> { let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()"; } +def ConvertATenToStd : Pass<"convert-aten-to-std", "FuncOp"> { + let summary = "Convert recognized ATen to Std ops"; + let constructor = "mlir::NPCOMP::createConvertATenToStdPass()"; +} + def ConvertATenToLinalg : Pass<"convert-aten-to-linalg", "FuncOp"> { let summary = "Convert recognized ATen to Linalg ops"; let description = [{ diff --git a/include/npcomp/Dialect/ATen/IR/ATenDialect.td b/include/npcomp/Dialect/ATen/IR/ATenDialect.td index 50fa167be..c81b2ed58 100644 --- a/include/npcomp/Dialect/ATen/IR/ATenDialect.td +++ b/include/npcomp/Dialect/ATen/IR/ATenDialect.td @@ -23,6 +23,7 @@ include "mlir/IR/OpBase.td" def ATen_Dialect : Dialect { let name = "aten"; let cppNamespace = "::mlir::NPCOMP::aten"; + let hasConstantMaterializer = 1; } //===----------------------------------------------------------------------===// diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc index f9d56ce30..2215e050d 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc @@ -868,6 +868,64 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() { return metadata; } +// ----------------------------------------------------------------------------- +// TorchScript primitive ops +// ----------------------------------------------------------------------------- + +Torch::KernelMetadata IsOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &IsOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::__is__"; + m.addArgTypes({"t1", "t2"}); + m.addArgConversions({KVC::kNone, KVC::kNone}); + m.addReturnTypes({"bool"}); + m.addReturnConversions({KVC::kNone}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata DimOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &DimOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::dim"; + m.addArgTypes({"Tensor"}); + m.addArgConversions({KVC::kImmutableTensor}); + m.addReturnTypes({"int"}); + m.addReturnConversions({KVC::kNone}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata NeIntOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &NeIntOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static 
Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::ne"; + m.addArgTypes({"int", "int"}); + m.addArgConversions({KVC::kNone, KVC::kNone}); + m.addReturnTypes({"bool"}); + m.addReturnConversions({KVC::kNone}); + return m; + })(); + return metadata; +} + // ----------------------------------------------------------------------------- // NN ops // ----------------------------------------------------------------------------- diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td index 98baf9ce2..ac2a6b23e 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td @@ -472,6 +472,43 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::__is__"; + let arguments = (ins + AnyTorchType:$self, + AnyTorchType:$obj + ); + let results = (outs + AnyTorchBoolType + ); + let hasFolder = 1; +} + +def aten_DimOp: aten_Op<"dim", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::dim"; + let arguments = (ins + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchIntType + ); +} + +def aten_NeIntOp: aten_Op<"ne.int", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::ne"; + let arguments = (ins + AnyTorchIntType:$a, + AnyTorchIntType:$b + ); + let results = (outs + AnyTorchBoolType + ); +} + // ----------------------------------------------------------------------------- // NN ops // ----------------------------------------------------------------------------- diff --git a/lib/Conversion/ATenToStd/ATenToStd.cpp b/lib/Conversion/ATenToStd/ATenToStd.cpp new file mode 100644 index 000000000..0c41ec6ac --- /dev/null +++ b/lib/Conversion/ATenToStd/ATenToStd.cpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Conversion/ATenToStd/ATenToStd.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "npcomp/Dialect/ATen/IR/ATenDialect.h" +#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +// ----------------------------------------------------------------------------- +// Patterns (as this grows, it should be organized into multiple files) +// ----------------------------------------------------------------------------- +// This is going to eventually be O(#aten ops), which is in the 100s. + +// Note: Confusingly, ATen's "dim" means "number of dimensions" which is what +// MLIR calls "rank". 
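+//
+// For concreteness, an illustrative sketch of what this pattern does (the
+// checked-in test test/Conversion/ATenToStd/basic.mlir exercises the same
+// rewrite):
+//
+//   %0 = "aten.dim"(%arg0) : (tensor<*x!numpy.any_dtype>) -> i64
+//
+// is rewritten to, approximately:
+//
+//   %rank = rank %arg0 : tensor<*x!numpy.any_dtype>
+//   %0 = index_cast %rank : index to i64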
+LogicalResult convertDimOp(aten::DimOp op, PatternRewriter &rewriter) { + if (!op.getOperand().getType().isa()) + return rewriter.notifyMatchFailure(op, "must be tensor only"); + auto rank = rewriter.create(op->getLoc(), op.getOperand()); + rewriter.replaceOpWithNewOp(op, op.getType(), rank); + return success(); +} + +LogicalResult convertNeIntOp(aten::NeIntOp op, PatternRewriter &rewriter) { + auto i1 = rewriter.create(op->getLoc(), CmpIPredicate::ne, + op->getOperand(0), op->getOperand(1)); + rewriter.replaceOpWithNewOp(op, op.getType(), i1); + return success(); +} + +// ----------------------------------------------------------------------------- +// The pass +// ----------------------------------------------------------------------------- + +namespace { +class ConvertATenToStd : public ConvertATenToStdBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns()); + } + + FrozenRewritePatternSet getPatterns() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(convertDimOp); + patterns.add(convertNeIntOp); + return std::move(patterns); + } +}; +} // namespace + +std::unique_ptr> +mlir::NPCOMP::createConvertATenToStdPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/ATenToStd/CMakeLists.txt b/lib/Conversion/ATenToStd/CMakeLists.txt new file mode 100644 index 000000000..4dc920fb7 --- /dev/null +++ b/lib/Conversion/ATenToStd/CMakeLists.txt @@ -0,0 +1,18 @@ +add_npcomp_conversion_library(NPCOMPATenToStd + ATenToStd.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/ATenToStd + + DEPENDS + NPCOMPConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRStandard + NPCOMPATenDialect +) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 4fda3d8e2..db6d8c1a9 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(ATenToLinalg) +add_subdirectory(ATenToStd) add_subdirectory(ATenToTCF) add_subdirectory(BasicpyToStd) add_subdirectory(NumpyToTCF) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 1085100a4..5efbff970 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,6 +9,7 @@ #include "npcomp/Conversion/Passes.h" #include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h" +#include "npcomp/Conversion/ATenToStd/ATenToStd.h" #include "npcomp/Conversion/ATenToTCF/Passes.h" #include "npcomp/Conversion/BasicpyToStd/Passes.h" #include "npcomp/Conversion/NumpyToTCF/Passes.h" diff --git a/lib/Dialect/ATen/IR/ATenDialect.cpp b/lib/Dialect/ATen/IR/ATenDialect.cpp index 87486f1fa..b922d72d8 100644 --- a/lib/Dialect/ATen/IR/ATenDialect.cpp +++ b/lib/Dialect/ATen/IR/ATenDialect.cpp @@ -10,7 +10,9 @@ #include "mlir/IR/DialectImplementation.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" +#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" +#include "npcomp/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; using namespace mlir::NPCOMP; @@ -96,17 +98,25 @@ void ATenDialect::printType(mlir::Type type, DialectAsmPrinter &os) const { } // namespace NPCOMP } // namespace mlir +Operation *ATenDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + // Bool (i1 -> !basicpy.BoolType). 
+ if (type.isa()) { + auto i1Value = value.dyn_cast(); + if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1) + return builder.create(loc, type, i1Value); + } + return nullptr; +} + void ATenDialect::initialize() { addTypes(); addOperations< #define GET_OP_LIST #include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc" >(); + getContext()->getOrLoadDialect("torch"); } -#define GET_OP_CLASSES -#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc" - #include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc" - -#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc" diff --git a/lib/Dialect/ATen/IR/ATenOps.cpp b/lib/Dialect/ATen/IR/ATenOps.cpp new file mode 100644 index 000000000..59e793c99 --- /dev/null +++ b/lib/Dialect/ATen/IR/ATenOps.cpp @@ -0,0 +1,40 @@ +//===- ATenDialect.cpp ------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Dialect/ATen/IR/ATenDialect.h" + +#include "mlir/IR/DialectImplementation.h" +#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" +#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" +#include "npcomp/Dialect/Torch/IR/TorchTypes.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::aten; + +//===----------------------------------------------------------------------===// +// IsOp +//===----------------------------------------------------------------------===// + +OpFoldResult IsOp::fold(ArrayRef operands) { + auto lhsType = self().getType(); + auto rhsType = obj().getType(); + // If either type is a NoneType, make it be the lhsType. + if (rhsType.isa()) + std::swap(lhsType, rhsType); + // TODO: Implement and use subtype infra for this. + // If neither type is a subtype of the other, then the result is false. + if (lhsType.isa() && !rhsType.isa()) + return IntegerAttr::get(IntegerType::get(getContext(), 1), 0); + return nullptr; +} + +#define GET_OP_CLASSES +#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc" + +#include "npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc" diff --git a/lib/Dialect/ATen/IR/CMakeLists.txt b/lib/Dialect/ATen/IR/CMakeLists.txt index 68f933850..f217b0013 100644 --- a/lib/Dialect/ATen/IR/CMakeLists.txt +++ b/lib/Dialect/ATen/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_npcomp_dialect_library(NPCOMPATenDialect ATenDialect.cpp + ATenOps.cpp ATenDialectOpStats.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index c2a5a8756..6a03f3787 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/Passes.h" #include "npcomp/Backend/Common/Passes.h" #include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h" +#include "npcomp/Conversion/ATenToStd/ATenToStd.h" #include "npcomp/Conversion/ATenToTCF/Passes.h" #include "npcomp/Conversion/TCFToStd/TCFToStd.h" #include "npcomp/Dialect/ATen/Transforms/Passes.h" @@ -106,6 +107,9 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline( // Recognize ATen kernels. This is a totally local transformation that // we want to run as soon as possible. pm.addNestedPass(aten::createRecognizeKernelsPass()); + // Convert any operations on primitive types as soon as possible. 
Unlike
+  // tensor compute ops, we don't need to wait for dtype/shape inference.
+  pm.addNestedPass<FuncOp>(createConvertATenToStdPass());
 
   if (options.optimize) {
     // OPT-ONLY: Right now we rely on this to eliminate certain branches that
diff --git a/test/Conversion/ATenToStd/basic.mlir b/test/Conversion/ATenToStd/basic.mlir
new file mode 100644
index 000000000..298b39ece
--- /dev/null
+++ b/test/Conversion/ATenToStd/basic.mlir
@@ -0,0 +1,22 @@
+// RUN: npcomp-opt <%s -convert-aten-to-std | FileCheck %s
+
+// CHECK-LABEL: func @aten.dim(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!numpy.any_dtype>) -> i64 {
+// CHECK: %[[RANK_INDEX:.*]] = rank %[[ARG0]] : tensor<*x!numpy.any_dtype>
+// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK_INDEX]] : index to i64
+// CHECK: return %[[RANK_I64]] : i64
+func @aten.dim(%arg0: tensor<*x!numpy.any_dtype>) -> i64 {
+  %0 = "aten.dim"(%arg0) : (tensor<*x!numpy.any_dtype>) -> i64
+  return %0 : i64
+}
+
+// CHECK-LABEL: func @aten.ne.int(
+// CHECK-SAME: %[[ARG0:.*]]: i64,
+// CHECK-SAME: %[[ARG1:.*]]: i64) -> !basicpy.BoolType {
+// CHECK: %[[I1:.*]] = cmpi ne, %[[ARG0]], %[[ARG1]] : i64
+// CHECK: %[[BASICPY_BOOL:.*]] = basicpy.bool_cast %[[I1]] : i1 -> !basicpy.BoolType
+// CHECK: return %[[BASICPY_BOOL]] : !basicpy.BoolType
+func @aten.ne.int(%arg0: i64, %arg1: i64) -> !basicpy.BoolType {
+  %0 = "aten.ne.int"(%arg0, %arg1) : (i64, i64) -> !basicpy.BoolType
+  return %0 : !basicpy.BoolType
+}
diff --git a/test/Dialect/ATen/canonicalize.mlir b/test/Dialect/ATen/canonicalize.mlir
new file mode 100644
index 000000000..5c24ef065
--- /dev/null
+++ b/test/Dialect/ATen/canonicalize.mlir
@@ -0,0 +1,9 @@
+// RUN: npcomp-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @aten.__is__
+// CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
+// CHECK: return %[[FALSE]] : !basicpy.BoolType
+func @aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType {
+  %0 = "aten.__is__"(%arg0, %arg1) : (!basicpy.ListType, !basicpy.NoneType) -> !basicpy.BoolType
+  return %0 : !basicpy.BoolType
+}
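
To illustrate how the new conversion composes with existing folding, here is a
hedged sketch (the `@dim_of_ranked` function is a hypothetical example, not a
checked-in test, and it assumes that `std.rank` folds on statically ranked
tensors and that `index_cast` folds constant operands): after
`-convert-aten-to-std -canonicalize`, `aten.dim` of a statically shaped tensor
is expected to collapse to a constant, the same style of local simplification
that, together with the `aten.__is__` folder above, helps prune the not-taken
branches mentioned in the commit message.

func @dim_of_ranked(%arg0: tensor<2x3xf32>) -> i64 {
  %0 = "aten.dim"(%arg0) : (tensor<2x3xf32>) -> i64
  return %0 : i64
}
// Expected to become, approximately:
//   %c2_i64 = constant 2 : i64
//   return %c2_i64 : i64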