From 6c702b149fdeb29670fb652996dfe60b22bd32e9 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 3 Nov 2020 19:24:28 -0800 Subject: [PATCH] Add a number of kernels and new patterns. * convolution, convolution_backward, _log_softmax, _log_softmax_backward_data, nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward, copy_ * Extends the recognition logic and metadata for handling inplace transformations, optional tensors, ints, lists and dropped args. * The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen. * The result *almost* comes out as a pure tensor program, with the exception of the copy_ op, which will be addressed in follow-up work. * More progress on #97 --- build_tools/update_aten_ods.sh | 3 +- .../csrc/c10_dispatch/func_builder.cpp | 5 +- .../codegen/torch_signature_ods_gen.py | 409 ++++++++++++++---- include/npcomp/Dialect/ATen/IR/.gitignore | 1 + include/npcomp/Dialect/ATen/IR/ATenOps.td | 49 --- .../Dialect/ATen/IR/GeneratedATenOps.cpp.inc | 222 ++++++++++ .../Dialect/ATen/IR/GeneratedATenOps.td | 144 ++++++ .../Dialect/ATen/IR/LegacyGeneratedATenOps.td | 155 ------- .../Dialect/ATen/Transforms/ATenToStd.td | 11 - .../npcomp/Dialect/Torch/IR/OpInterfaces.h | 16 +- include/npcomp/Dialect/Torch/IR/TorchBase.td | 26 ++ lib/Dialect/ATen/IR/ATenDialectOpStats.cpp | 63 --- .../ATen/Transforms/RecognizeKernelsPass.cpp | 113 ++++- test/Dialect/ATen/aten_conv2d.mlir | 25 -- test/Dialect/ATen/aten_conv2d_back.mlir | 23 - test/Dialect/ATen/lenet_fwd.mlir | 26 -- test/Dialect/ATen/recognize_aten_kernels.mlir | 83 ++++ 17 files changed, 920 insertions(+), 454 deletions(-) create mode 100644 include/npcomp/Dialect/ATen/IR/.gitignore delete mode 100644 test/Dialect/ATen/aten_conv2d.mlir delete mode 100644 test/Dialect/ATen/aten_conv2d_back.mlir delete mode 100644 test/Dialect/ATen/lenet_fwd.mlir diff --git a/build_tools/update_aten_ods.sh b/build_tools/update_aten_ods.sh index 7ffda337a..3ee8f7069 100755 --- a/build_tools/update_aten_ods.sh +++ b/build_tools/update_aten_ods.sh @@ -11,4 +11,5 @@ export PYTHONPATH="${build_dir}/python" python -m torch_mlir_utils.codegen.torch_signature_ods_gen \ --ods_td_file="${aten_dir}/GeneratedATenOps.td" \ - --ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" + --ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" \ + --debug_op_reg_file="${aten_dir}/ATenOpRegistrations.txt" diff --git a/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp b/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp index f9a982015..3e098cb50 100644 --- a/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp +++ b/frontends/pytorch/csrc/c10_dispatch/func_builder.cpp @@ -189,9 +189,8 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc, return constValue; } -MlirValue -FuncBuilder::buildList(MlirLocation loc, - llvm::SmallVectorImpl &elements) { +MlirValue FuncBuilder::buildList(MlirLocation loc, + llvm::SmallVectorImpl &elements) { MlirType resultType = npcompListTypeGet(context); OperationStateHolder state{"basicpy.build_list", loc}; mlirOperationStateAddResults(state, 1, &resultType); diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py index 1e4b06349..efb25fcb3 100644 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py +++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_signature_ods_gen.py @@ -13,6 +13,7 @@ import logging import re import sys import 
textwrap +import traceback # Note that this utility exists only in the c-extension. from _torch_mlir import get_registered_ops @@ -75,6 +76,55 @@ def generate_ops(g: "OpGenerator"): g.ordinary_unary_op(f"aten::{uname}(Tensor)", f"{snakecase_to_camelcase(uname)}Op", uname) + # Convolution ops. Note that these are special in PyTorch and the importer, + # and we model them after the signatures of the convolution_overrideable + # ops (generic for non-CPU/GPU backends) but set the names according to + # how they come in. + g.print_banner("NN ops") + g.ordinary_immutable_op( + "aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)", + "ConvolutionOp", + "convolution", + alias_kernel_names=["aten::convolution"]) + g.ordinary_immutable_op( + "aten::convolution_backward_overrideable(Tensor,Tensor,Tensor,int[],int[],int[],bool,int[],int,bool[])", + "ConvolutionBackwardOp", + "convolution_backward", + alias_kernel_names=["aten::convolution_backward"], + # These do return as None but are not coded optional in the registry :( + override_return_types=["Tensor?", "Tensor?", "Tensor?"]) + + g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)", + "LogSoftmaxOp", "log_softmax") + g.ordinary_immutable_op( + "aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)", + "LogSoftmaxBackwardDataOp", "log_softmax_backward_data") + + # Loss functions. + g.print_banner("Loss function ops") + g.ordinary_immutable_op( + "aten::nll_loss_forward(Tensor,Tensor,Tensor?,int,int)", + "NllLossForwardOp", "nll_loss_forward") + # Note also a grad_input 8-arg variant. + g.ordinary_immutable_op( + "aten::nll_loss_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)", + "NllLossBackwardOp", "nll_loss_backward") + + g.ordinary_immutable_op( + "aten::nll_loss2d_forward(Tensor,Tensor,Tensor?,int,int)", + "NllLoss2dForwardOp", "nll_loss2d_forward") + # Note also a grad_input 8-arg variant. + g.ordinary_immutable_op( + "aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)", + "NllLoss2dBackwardOp", "nll_loss2d_backward") + + # One-off in-place ops (note that many in-place arithmetic ops are handled + # as a transformation from their immutable forms). + g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)", + "CopyInplaceOp", + "copy.inplace", + drop_arg_indices=[2]) + def dump_registered_ops(outfile, reg_ops_dict): for k in sorted(reg_ops_dict.keys()): @@ -104,7 +154,20 @@ class OpGenerator: ) em.print("") - def ordinary_binary_op(self, kernel_sig, ods_name, op_name): + def define_op(self, kernel_sig, ods_name, op_name, **kwargs): + return InflightOpDef(self, + kernel_sig=kernel_sig, + ods_name=ods_name, + op_name=op_name, + **kwargs) + + def ordinary_binary_op(self, + kernel_sig, + ods_name, + op_name, + promote_trailing_out_tensor=True, + traits=(), + **kwargs): """"Binary"-ops. These ops typically have: - '.Tensor' variant where the second arg is a Tensor - '.Scalar' variant where the second arg is a Scalar @@ -124,75 +187,130 @@ class OpGenerator: - Setting all arguments and returns to kImmutableTensor - Enabling kPromoteScalarToTensor on the second argument. 
""" - reg_record = self._get_reg_record(kernel_sig) - ods_ins, arg_type_flags = self._map_sigtypes( - reg_record["arguments"], - type_transforms={ - "Tensor:0": "AnyTorchImmutableTensor", - "Tensor:1": "AnyTorchImmutableTensor", - "Scalar:1": "AnyTorchImmutableTensor", - "Scalar": "AnyTorchScalarType", - }, - flag_transforms={ - ":0": ["kImmutableTensor"], - ":1": ["kImmutableTensor", "kPromoteScalar"], - }) - ods_outs, return_type_flags = self._map_sigtypes( - reg_record["returns"], - type_transforms={ - "Tensor:0": "AnyTorchImmutableTensor", - }, - flag_transforms={ - ":0": ["kImmutableTensor"], - }) - self.ods_emitter.emit_opdef(ods_name, - op_name, - reg_record, - ods_ins=ods_ins, - ods_outs=ods_outs, - traits=["NoSideEffect"]) - self.impl_emitter.emit_kernel_methods(ods_name, - reg_record, - arg_type_flags=arg_type_flags, - return_type_flags=return_type_flags, - promote_trailing_out_tensor=True) + opdef = self.define_op( + kernel_sig=kernel_sig, + ods_name=ods_name, + op_name=op_name, + promote_trailing_out_tensor=promote_trailing_out_tensor, + traits=list(traits) + ["NoSideEffect"], + **kwargs) + opdef.arg_transforms(type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + "Tensor:1": "AnyTorchImmutableTensor", + "Scalar:1": "AnyTorchImmutableTensor", + "Scalar": "AnyTorchScalarType", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + ":1": ["kImmutableTensor", "kPromoteScalar"], + }) + opdef.return_transforms(type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + }) + opdef.emit() - def ordinary_unary_op(self, kernel_sig, ods_name, op_name): + def ordinary_immutable_op(self, + kernel_sig, + ods_name, + op_name, + promote_trailing_out_tensor=True, + traits=(), + **kwargs): + """"An ordinary immutable-tensor based op.""" + opdef = self.define_op( + kernel_sig=kernel_sig, + ods_name=ods_name, + op_name=op_name, + promote_trailing_out_tensor=promote_trailing_out_tensor, + traits=list(traits) + ["NoSideEffect"], + **kwargs) + opdef.transforms(type_transforms={ + "Tensor": "AnyTorchImmutableTensor", + "Tensor?": "AnyTorchOptionalImmutableTensor", + "int": "AnyTorchIntType", + "int[]": "AnyTorchIntListType", + "bool": "AnyTorchBoolType", + "bool[]": "AnyTorchBoolListType", + }, + flag_transforms={ + "Tensor": ["kImmutableTensor"], + "Tensor?": ["kImmutableTensor"], + }) + opdef.emit() + + def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs): + """In-place ops (ending in '_'). + + These ops take a mutable first argument and then standard immutable + conversions for subsequent. When emitting into MLIR, the return value is + dropped. + """ + opdef = self.define_op(kernel_sig=kernel_sig, + ods_name=ods_name, + op_name=op_name, + **kwargs) + opdef.arg_transforms(type_transforms={ + ":0": "AnyTorchMutableTensor", + "Tensor": "AnyTorchImmutableTensor", + "Tensor?": "AnyTorchOptionalImmutableTensor", + "int": "AnyTorchIntType", + "int[]": "AnyTorchIntListType", + "bool": "AnyTorchBoolType", + "bool[]": "AnyTorchBoolListType", + }, + flag_transforms={ + ":0": [], + "Tensor": ["kImmutableTensor"], + "Tensor?": ["kImmutableTensor"], + }) + opdef.return_transforms( + type_transforms={ + ":0": "DROP_UNUSED", # Ignored because we clear the outs below. + }, + flag_transforms={ + ":0": ["kDropReturnAndAliasArg0"], + }) + opdef.map_signatures() + opdef.ods_outs = [] # Clear the computed outs. 
+ opdef.emit() + + def ordinary_unary_op(self, + kernel_sig, + ods_name, + op_name, + promote_trailing_out_tensor=True, + traits=(), + **kwargs): """Unary ops. These take and return a tensor and typically have an out and inplace variant (they may not but we generate patterns to match anyway). """ - reg_record = self._get_reg_record(kernel_sig) - ods_ins, arg_type_flags = self._map_sigtypes( - reg_record["arguments"], - type_transforms={ - "Tensor:0": "AnyTorchImmutableTensor", - }, - flag_transforms={ - ":0": ["kImmutableTensor"], - }) - ods_outs, return_type_flags = self._map_sigtypes( - reg_record["returns"], - type_transforms={ - "Tensor:0": "AnyTorchImmutableTensor", - }, - flag_transforms={ - ":0": ["kImmutableTensor"], - }) - self.ods_emitter.emit_opdef(ods_name, - op_name, - reg_record, - ods_ins=ods_ins, - ods_outs=ods_outs, - traits=["NoSideEffect"]) - self.impl_emitter.emit_kernel_methods(ods_name, - reg_record, - arg_type_flags=arg_type_flags, - return_type_flags=return_type_flags, - promote_trailing_out_tensor=True) + opdef = self.define_op( + kernel_sig=kernel_sig, + ods_name=ods_name, + op_name=op_name, + promote_trailing_out_tensor=promote_trailing_out_tensor, + traits=list(traits) + ["NoSideEffect"], + **kwargs) + opdef.arg_transforms(type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + }) + opdef.return_transforms(type_transforms={ + "Tensor:0": "AnyTorchImmutableTensor", + }, + flag_transforms={ + ":0": ["kImmutableTensor"], + }) + opdef.emit() - def _get_reg_record(self, kernel_sig): + def get_reg_record(self, kernel_sig): """Gets the op-dict for a given registered op name. Args: @@ -212,8 +330,14 @@ class OpGenerator: raise ValueError(f"Could not find registry op matching '{kernel_sig}'. " f"Possible matches:\n {dym_message}") - def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str], - flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]: + def _map_sigtypes( + self, + siglist: List[Dict], + type_transforms: Dict[str, str], + flag_transforms: Dict[str, List[str]], + drop_indices: Sequence[int] = (), + override_types: Optional[Sequence[str]] = None, + ) -> List[Tuple[str]]: """Maps a list of signature entries to ods dags and flag lists. The torch signature list contains dicts that minimally have keys 'name' and @@ -233,15 +357,23 @@ class OpGenerator: - An ods dag list of (ods_name, ods_type) tuples - List of (torch_type, [conversion_flag]) for specifying conversions. """ + # Make sure any override types are sane. + if override_types: + assert len(override_types) == len(siglist), ( + "Mismatch override and actual types") # Generate to ods dag list. ods_dag_list = [] for i, sigitem in enumerate(siglist): + if i in drop_indices: + # Do not emit in ODS. + continue torch_name = sigitem["name"] - torch_type = sigitem["type"] + torch_type = (sigitem["type"] + if override_types is None else override_types[i]) # Look up the type transform. - ods_type = (type_transforms.get(f"{torch_type}:{i}") or - type_transforms.get(f":{i}") or - type_transforms.get(torch_type)) + ods_type = _first_non_none(type_transforms.get(f"{torch_type}:{i}"), + type_transforms.get(f":{i}"), + type_transforms.get(torch_type)) if not ods_type: raise ValueError(f"Signature item {i}, type {torch_type} did not match " f"a type transform {type_transforms}") @@ -250,16 +382,130 @@ class OpGenerator: # Generate the type conversion flags. 
type_flag_list = [] for i, sigitem in enumerate(siglist): - torch_type = sigitem["type"] + torch_type = (sigitem["type"] + if override_types is None else override_types[i]) # Look up the type transform. - flags = (flag_transforms.get(f"{torch_type}:{i}") or - flag_transforms.get(f":{i}") or flag_transforms.get(torch_type)) - if not flags: - flags = [] + if i in drop_indices: + flags = ["kDrop"] + else: + flags = _first_non_none(flag_transforms.get(f"{torch_type}:{i}"), + flag_transforms.get(f":{i}"), + flag_transforms.get(torch_type)) + if flags is None: + flags = [] type_flag_list.append((torch_type, flags)) return ods_dag_list, type_flag_list +class InflightOpDef: + + def __init__(self, + g: OpGenerator, + kernel_sig, + ods_name, + op_name, + traits=(), + alias_kernel_names=(), + promote_trailing_out_tensor=False, + override_arg_types=None, + override_return_types=None, + drop_arg_indices=(), + drop_return_indices=()): + super().__init__() + self.g = g + self.kernel_sig = kernel_sig + self.ods_name = ods_name + self.op_name = op_name + self.traits = list(traits) + self.alias_kernel_names = list(alias_kernel_names) + self.promote_trailing_out_tensor = promote_trailing_out_tensor + self.override_arg_types = override_arg_types + self.override_return_types = override_return_types + self.drop_arg_indices = drop_arg_indices + self.drop_return_indices = drop_return_indices + self.reg_record = g.get_reg_record(self.kernel_sig) + self._emitted = False + self._traceback = traceback.extract_stack()[0:-2] + + # Arg and flag transform dicts. + self.arg_type_transforms = dict() + self.arg_flag_transforms = dict() + self.return_type_transforms = dict() + self.return_flag_trasforms = dict() + + # Signature mapping. + self._sigs_mapped = False + self.ods_ins = None + self.ods_outs = None + self.arg_type_flags = None + self.return_type_flags = None + + def __del__(self): + if not self._emitted: + print("WARNING: Op defined but not emitted. 
Defined at:", file=sys.stderr) + for line in traceback.format_list(self._traceback): + sys.stderr.write(line) + + def transforms(self, type_transforms=None, flag_transforms=None): + self.arg_transforms(type_transforms=type_transforms, + flag_transforms=flag_transforms) + self.return_transforms(type_transforms=type_transforms, + flag_transforms=flag_transforms) + return self + + def arg_transforms(self, type_transforms=None, flag_transforms=None): + """Adds arg type and flag transforms dicts.""" + if type_transforms: + self.arg_type_transforms.update(type_transforms) + if flag_transforms: + self.arg_flag_transforms.update(flag_transforms) + return self + + def return_transforms(self, type_transforms=None, flag_transforms=None): + """Adds return type and flag transform dicts.""" + if type_transforms: + self.return_type_transforms.update(type_transforms) + if flag_transforms: + self.return_flag_trasforms.update(flag_transforms) + return self + + def map_signatures(self): + assert not self._sigs_mapped, "Signatures already mapped" + self._sigs_mapped = True + self.ods_ins, self.arg_type_flags = self.g._map_sigtypes( + self.reg_record["arguments"], + type_transforms=self.arg_type_transforms, + flag_transforms=self.arg_flag_transforms, + override_types=self.override_arg_types, + drop_indices=self.drop_arg_indices) + self.ods_outs, self.return_type_flags = self.g._map_sigtypes( + self.reg_record["returns"], + type_transforms=self.return_type_transforms, + flag_transforms=self.return_flag_trasforms, + override_types=self.override_return_types, + drop_indices=self.drop_return_indices) + return self + + def emit(self): + assert not self._emitted, "Op already emitted" + self._emitted = True + if not self._sigs_mapped: + self.map_signatures() + self.g.ods_emitter.emit_opdef(self.ods_name, + self.op_name, + self.reg_record, + ods_ins=self.ods_ins, + ods_outs=self.ods_outs, + traits=self.traits) + self.g.impl_emitter.emit_kernel_methods( + self.ods_name, + self.reg_record, + arg_type_flags=self.arg_type_flags, + return_type_flags=self.return_type_flags, + promote_trailing_out_tensor=self.promote_trailing_out_tensor, + alias_kernel_names=self.alias_kernel_names) + + class EmitterBase: _INDENT = " " @@ -355,7 +601,8 @@ class CCImplEmitter(EmitterBase): reg_record, arg_type_flags: List[Tuple[str, List[Tuple[str]]]], return_type_flags: List[Tuple[str, List[Tuple[str]]]], - promote_trailing_out_tensor=False): + promote_trailing_out_tensor=False, + alias_kernel_names: Sequence[str] = ()): # getTorchKernelMetadata() method. self.print( f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{") @@ -374,6 +621,9 @@ class CCImplEmitter(EmitterBase): with self.indent(): self.print("Torch::BuildKernelMetadata m;") self.print(f"m.kernelName = {self.quote(kernel_name)};") + for alias_kernel_name in alias_kernel_names: + self.print( + f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});") if promote_trailing_out_tensor: self.print("m.promoteTrailingOutTensor = true;") # Arg types/flags. 
@@ -393,7 +643,7 @@ class CCImplEmitter(EmitterBase): self.print("return m;") self.print("})();") self.print("return metadata;") - self.print("}") + self.print("}\n") def _format_cpp_str_initlist(self, strings): quoted = [self.quote(s) for s in strings] @@ -416,6 +666,13 @@ def snakecase_to_camelcase(ident: str): return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident)) +def _first_non_none(*args): + for arg in args: + if arg is not None: + return arg + return None + + def _load_ops_as_dict(): # Returns a list of dicts, each with a name that is a tuple of the form: # (kernel_signature, variant) diff --git a/include/npcomp/Dialect/ATen/IR/.gitignore b/include/npcomp/Dialect/ATen/IR/.gitignore new file mode 100644 index 000000000..6f78ac32d --- /dev/null +++ b/include/npcomp/Dialect/ATen/IR/.gitignore @@ -0,0 +1 @@ +ATenOpRegistrations.txt diff --git a/include/npcomp/Dialect/ATen/IR/ATenOps.td b/include/npcomp/Dialect/ATen/IR/ATenOps.td index 380eae21e..61d01a059 100644 --- a/include/npcomp/Dialect/ATen/IR/ATenOps.td +++ b/include/npcomp/Dialect/ATen/IR/ATenOps.td @@ -61,55 +61,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>, } -// Our jit library only supports 6 argument convolutions, rather than 9 -// arguments supported by pytorch. This operation allows us to represent this -// limitation temporarily. -def aten_ConvolutionOp: aten_Op<"_convolution", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$input, - AnyTensor:$weight, - AnyTensor:$bias, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation - ); - - let summary = "Convolution operator"; - let description = [{ - Convolution operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - uint64_t getOperandTransferVolume(unsigned int idx, bool read); - uint64_t getResultTransferVolume(unsigned int idx, bool read); - }]; -} - -// Our jit library only supports 6 argument convolutions, rather than 9 -// arguments supported by pytorch. This operation allows us to represent this -// limitation temporarily. 
-def aten_ConvolutionBackwardOp: aten_Op<"_convolution_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor:$dx, AnyTensor:$dw, AnyTensor:$db)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$input, - AnyTensor:$weight, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation - ); - - let summary = "ConvolutionBackward operator"; - let description = [{ - ConvolutionBackward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - - def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>, Results<(outs AnyTensor)> { let arguments = ( diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc index 5896a0471..5b6d453e0 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc +++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc @@ -37,6 +37,7 @@ const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -55,6 +56,7 @@ const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata DivOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -73,6 +75,7 @@ const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -91,6 +94,7 @@ const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata MulOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -109,6 +113,7 @@ const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -127,6 +132,7 @@ const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -145,6 +151,7 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() { })(); return metadata; } + // ----------------------------------------------------------------------------- // Unary arithmetic ops // ----------------------------------------------------------------------------- @@ -167,6 +174,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata AcosOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -185,6 +193,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata AngleOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -203,6 +212,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata AsinOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -221,6 +231,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata AtanOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -239,6 +250,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() { })(); return metadata; } + 
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -257,6 +269,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ConjOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -275,6 +288,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata CosOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -293,6 +307,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata CoshOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -311,6 +326,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -329,6 +345,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ErfOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -347,6 +364,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -365,6 +383,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -383,6 +402,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ExpOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -401,6 +421,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -419,6 +440,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata FloorOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -437,6 +459,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata FracOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -455,6 +478,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -473,6 +497,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata LogOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -491,6 +516,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata Log10Op::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -509,6 +535,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -527,6 +554,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata 
Log2Op::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -545,6 +573,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata NegOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -563,6 +592,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ReluOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -581,6 +611,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -599,6 +630,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata RoundOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -617,6 +649,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -635,6 +668,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -653,6 +687,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata SignOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -671,6 +706,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata SinOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -689,6 +725,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata SinhOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -707,6 +744,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -725,6 +763,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata TanOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -743,6 +782,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata TanhOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -761,6 +801,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() { })(); return metadata; } + Torch::KernelMetadata TruncOp::getTorchKernelMetadata() { return getTorchBuildKernelMetadata(); } @@ -779,3 +820,184 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() { })(); return metadata; } + +// ----------------------------------------------------------------------------- +// NN ops +// ----------------------------------------------------------------------------- + +Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + 
Torch::BuildKernelMetadata m; + m.kernelName = "aten::convolution_overrideable"; + m.aliasKernelNames.push_back("aten::convolution"); + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &ConvolutionBackwardOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::convolution_backward_overrideable"; + m.aliasKernelNames.push_back("aten::convolution_backward"); + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone}); + m.addReturnTypes({"Tensor?", "Tensor?", "Tensor?"}); + m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata LogSoftmaxOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &LogSoftmaxOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::_log_softmax"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "int", "bool"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata LogSoftmaxBackwardDataOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::_log_softmax_backward_data"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "int", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +// ----------------------------------------------------------------------------- +// Loss function ops +// ----------------------------------------------------------------------------- + +Torch::KernelMetadata NllLossForwardOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &NllLossForwardOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::nll_loss_forward"; + m.promoteTrailingOutTensor = true; + 
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone}); + m.addReturnTypes({"Tensor", "Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata NllLossBackwardOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &NllLossBackwardOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::nll_loss_backward"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata NllLoss2dForwardOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &NllLoss2dForwardOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::nll_loss2d_forward"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone}); + m.addReturnTypes({"Tensor", "Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata NllLoss2dBackwardOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::nll_loss2d_backward"; + m.promoteTrailingOutTensor = true; + m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"}); + m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kImmutableTensor}); + return m; + })(); + return metadata; +} + +Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() { + return getTorchBuildKernelMetadata(); +} + +const Torch::BuildKernelMetadata &CopyInplaceOp::getTorchBuildKernelMetadata() { + using KVC = Torch::KernelValueConversion::BitMask; + static Torch::BuildKernelMetadata metadata = ([]() { + Torch::BuildKernelMetadata m; + m.kernelName = "aten::copy_"; + m.addArgTypes({"Tensor", "Tensor", "bool"}); + m.addArgConversions({KVC::kNone, KVC::kImmutableTensor, KVC::kDrop}); + m.addReturnTypes({"Tensor"}); + m.addReturnConversions({KVC::kDropReturnAndAliasArg0}); + return m; + })(); + return metadata; +} + diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td index d25a8d1ee..34a0099bb 100644 --- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td +++ 
b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td @@ -450,3 +450,147 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::convolution_overrideable"; + let arguments = (ins + AnyTorchImmutableTensor:$input, + AnyTorchImmutableTensor:$weight, + AnyTorchOptionalImmutableTensor:$bias, + AnyTorchIntListType:$stride, + AnyTorchIntListType:$padding, + AnyTorchIntListType:$dilation, + AnyTorchBoolType:$transposed, + AnyTorchIntListType:$output_padding, + AnyTorchIntType:$groups + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::convolution_backward_overrideable"; + let arguments = (ins + AnyTorchImmutableTensor:$grad_output, + AnyTorchImmutableTensor:$input, + AnyTorchImmutableTensor:$weight, + AnyTorchIntListType:$stride, + AnyTorchIntListType:$padding, + AnyTorchIntListType:$dilation, + AnyTorchBoolType:$transposed, + AnyTorchIntListType:$output_padding, + AnyTorchIntType:$groups, + AnyTorchBoolListType:$output_mask + ); + let results = (outs + AnyTorchOptionalImmutableTensor:$grad_input, + AnyTorchOptionalImmutableTensor:$grad_weight, + AnyTorchOptionalImmutableTensor:$grad_bias + ); +} + +def aten_LogSoftmaxOp: aten_Op<"log_softmax", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::_log_softmax"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchIntType:$dim, + AnyTorchBoolType:$half_to_float + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::_log_softmax_backward_data"; + let arguments = (ins + AnyTorchImmutableTensor:$grad_output, + AnyTorchImmutableTensor:$output, + AnyTorchIntType:$dim, + AnyTorchImmutableTensor:$self + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +// ----------------------------------------------------------------------------- +// Loss function ops +// ----------------------------------------------------------------------------- + +def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::nll_loss_forward"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$target, + AnyTorchOptionalImmutableTensor:$weight, + AnyTorchIntType:$reduction, + AnyTorchIntType:$ignore_index + ); + let results = (outs + AnyTorchImmutableTensor:$output, + AnyTorchImmutableTensor:$total_weight + ); +} + +def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::nll_loss_backward"; + let arguments = (ins + AnyTorchImmutableTensor:$grad_output, + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$target, + AnyTorchOptionalImmutableTensor:$weight, + AnyTorchIntType:$reduction, + AnyTorchIntType:$ignore_index, + AnyTorchImmutableTensor:$total_weight + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, DeclareOpInterfaceMethods, 
DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::nll_loss2d_forward"; + let arguments = (ins + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$target, + AnyTorchOptionalImmutableTensor:$weight, + AnyTorchIntType:$reduction, + AnyTorchIntType:$ignore_index + ); + let results = (outs + AnyTorchImmutableTensor:$output, + AnyTorchImmutableTensor:$total_weight + ); +} + +def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::nll_loss2d_backward"; + let arguments = (ins + AnyTorchImmutableTensor:$grad_output, + AnyTorchImmutableTensor:$self, + AnyTorchImmutableTensor:$target, + AnyTorchOptionalImmutableTensor:$weight, + AnyTorchIntType:$reduction, + AnyTorchIntType:$ignore_index, + AnyTorchImmutableTensor:$total_weight + ); + let results = (outs + AnyTorchImmutableTensor + ); +} + +def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "Recognized op for kernel aten::copy_"; + let arguments = (ins + AnyTorchMutableTensor:$self, + AnyTorchImmutableTensor:$src + ); + let results = (outs + ); +} + diff --git a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td index 0e181f938..7a7ed7fbf 100644 --- a/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td +++ b/include/npcomp/Dialect/ATen/IR/LegacyGeneratedATenOps.td @@ -44,52 +44,6 @@ def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface }]; } -def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$input, - AnyTensor:$weight, - AnyTensor:$bias, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation, - AnyScalar:$transposed, - AnyType:$output_padding, - AnyScalar:$groups - ); - let summary = "aten convolution_overrideable operator"; - let description = [{ - ConvolutionOverrideableOp - aten convolution_overrideable operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - -def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$input, - AnyTensor:$weight, - AnyType:$stride, - AnyType:$padding, - AnyType:$dilation, - AnyScalar:$transposed, - AnyType:$output_padding, - AnyScalar:$groups - ); - let summary = "aten convolution_backward_overrideable operator"; - let description = [{ - ConvolutionBackwardOverrideableOp - aten convolution_backward_overrideable operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>, Results<(outs AnyTensor)> { let arguments = ( @@ -123,35 +77,6 @@ def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>, }]; } -def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyScalar:$dim, - AnyScalar:$half_to_float - ); - let summary = "aten _log_softmax operator"; - let description = [{ - LogSoftmaxOp - aten _log_softmax operator - }]; -} - -def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>, - Results<(outs AnyTensor)> { - let 
arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$output, - AnyScalar:$dim, - AnyTensor:$self - ); - let summary = "aten _log_softmax_backward_data operator"; - let description = [{ - LogSoftmaxBackwardDataOp - aten _log_softmax_backward_data operator - }]; -} - def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>, Results<(outs AnyTensor)> { let arguments = ( @@ -445,86 +370,6 @@ def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>, }]; } -def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index - ); - let summary = "aten nll_loss_forward operator"; - let description = [{ - NllLossForwardOp - aten nll_loss_forward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - -def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index, - AnyTensor:$total_weight - ); - let summary = "aten nll_loss_backward operator"; - let description = [{ - NllLossBackwardOp - aten nll_loss_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - -def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor, AnyTensor)> { - let arguments = ( - ins AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index - ); - let summary = "aten nll_loss2d_forward operator"; - let description = [{ - NllLoss2dForwardOp - aten nll_loss2d_forward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - -def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>, - Results<(outs AnyTensor)> { - let arguments = ( - ins AnyTensor:$grad_output, - AnyTensor:$self, - AnyTensor:$target, - AnyTensor:$weight, - AnyScalar:$reduction, - AnyScalar:$ignore_index, - AnyTensor:$total_weight - ); - let summary = "aten nll_loss2d_backward operator"; - let description = [{ - NllLoss2dBackwardOp - aten nll_loss2d_backward operator - }]; - let extraClassDeclaration = [{ - std::map getStatistics(); - }]; -} - def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>, Results<(outs AnyTensor)> { let arguments = ( diff --git a/include/npcomp/Dialect/ATen/Transforms/ATenToStd.td b/include/npcomp/Dialect/ATen/Transforms/ATenToStd.td index f2bb2c077..2d4a697c6 100644 --- a/include/npcomp/Dialect/ATen/Transforms/ATenToStd.td +++ b/include/npcomp/Dialect/ATen/Transforms/ATenToStd.td @@ -12,15 +12,4 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" include "npcomp/Dialect/ATen/IR/ATenOps.td" -// The pytorch convolution operator has 9 arguments, but we only have a jit -// library that supports the first six at the moment. 
-def : Pat<(aten_ConvolutionOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6, - $a7, $a8, $a9), - (aten_ConvolutionOp $a1, $a2, $a3, $a4, $a5, $a6)>; - -def : Pat<(aten_ConvolutionBackwardOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6, - $a7, $a8, $a9), - (aten_ConvolutionBackwardOp $a1, $a2, $a3, $a4, $a5, $a6)>; - - #endif diff --git a/include/npcomp/Dialect/Torch/IR/OpInterfaces.h b/include/npcomp/Dialect/Torch/IR/OpInterfaces.h index 0d338e2ee..b86bac4b5 100644 --- a/include/npcomp/Dialect/Torch/IR/OpInterfaces.h +++ b/include/npcomp/Dialect/Torch/IR/OpInterfaces.h @@ -18,7 +18,7 @@ namespace Torch { /// Conversion rule to apply to a value (argument or return). namespace KernelValueConversion { -enum BitMask { +enum BitMask : uint32_t { // No coercion. kNone = 0, @@ -32,7 +32,16 @@ enum BitMask { // to a 0d tensor. kPromoteScalar = 8, - LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar) + // Drops the return value and aliases to argument 0. + // TODO: Remove this in favor of general alias metadata processing (note that + // the vast majority are this one case so it isn't so bad to have a special + // case for it if necessary). + kDropReturnAndAliasArg0 = 16, + + // Drops the argument/return. + kDrop = 32, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kDrop) }; LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); } // namespace KernelValueConversion @@ -74,6 +83,9 @@ struct BuildKernelMetadata : public KernelMetadata { SmallVector argConversions; SmallVector returnConversions; + /// Additional alias kernel names to match. + SmallVector aliasKernelNames; + void addArgConversions( std::initializer_list ilist) { argConversions.insert(argConversions.end(), ilist); diff --git a/include/npcomp/Dialect/Torch/IR/TorchBase.td b/include/npcomp/Dialect/Torch/IR/TorchBase.td index 6d4a4304c..1dad1bc0a 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchBase.td +++ b/include/npcomp/Dialect/Torch/IR/TorchBase.td @@ -72,6 +72,11 @@ def AnyTorchImmutableTensor : AnyTypeOf<[ AnyTensor, ], "allowable torch immutable tensor">; +def AnyTorchOptionalImmutableTensor : AnyTypeOf<[ + AnyTorchImmutableTensor, + Basicpy_NoneType, +], "allowable torch immutable tensor (or None)">; + def AnyTorchMutableTensor : AnyTypeOf<[ // "Numpy-style" mutable NDArray. While not offering the full generality // of a Torch tensor, it models the same access patterns and implies the @@ -95,7 +100,28 @@ def AnyTorchScalarType : AnyTypeOf<[ AnySignlessInteger, ], "Any primitive type suitable to be passed as a Torch Scalar">; +def AnyTorchBoolType : AnyTypeOf<[ + I1, + Basicpy_BoolType, +], "Any permissible bool type">; + +def AnyTorchBoolListType : AnyTypeOf<[ + Basicpy_ListType, + // TODO: Support typed list when available. +], "Any bool list type (bool[])">; + +def AnyTorchIntType : AnyTypeOf<[ + AnySignedInteger, + AnySignlessInteger, +], "Any primitive integer type suitable to be passed as a Torch 'int'">; + +def AnyTorchIntListType : AnyTypeOf<[ + Basicpy_ListType, + // TODO: Support typed list when available. 
+], "Any int list type (int[])">; + def AnyTorchType : AnyTypeOf<[ + AnyTorchBoolType, AnyTorchScalarType, AnyTorchTensorType, Basicpy_ListType, diff --git a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp index cbf4d1348..1d7aab620 100644 --- a/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp +++ b/lib/Dialect/ATen/IR/ATenDialectOpStats.cpp @@ -170,40 +170,6 @@ std::map BatchNormOp::getStatistics() { return toReturn; } -// _convolution -std::map ConvolutionOp::getStatistics() { - return getConv2dStatistics(this, /*groups*/ 1); -} -std::map ConvolutionOverrideableOp::getStatistics() { - // FIXME - auto co = cast(groups().getDefiningOp()); - auto ia = co.template getAttrOfType("value"); - uint64_t groups = ia.getValue().getZExtValue(); - - return getConv2dStatistics(this, groups); -} - -uint64_t ConvolutionOp::getOperandTransferVolume(unsigned int idx, bool read) { - return getConv2dOperandTransferVolume(this, idx, read); -} - -uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) { - return getConv2dResultTransferVolume(this, idx, write); -} - -// _convolution_backward -std::map ConvolutionBackwardOp::getStatistics() { - return getConv2dBackwardStatistics(*this, 1); -} -std::map -ConvolutionBackwardOverrideableOp::getStatistics() { - auto co = cast(groups().getDefiningOp()); - auto ia = co.template getAttrOfType("value"); - uint64_t groups = ia.getValue().getZExtValue(); - - return getConv2dBackwardStatistics(*this, groups); -} - // div_ std::map DivUnderOp::getStatistics() { @@ -559,35 +525,6 @@ std::map NativeBatchNormBackwardOp::getStatistics() { return toReturn; } -std::map NllLossForwardOp::getStatistics() { - std::map toReturn; - // FIXME: unimplemented - toReturn["reads"] = -1; - toReturn["writes"] = -1; - return toReturn; -} -std::map NllLossBackwardOp::getStatistics() { - std::map toReturn; - // FIXME: unimplemented - toReturn["reads"] = -1; - toReturn["writes"] = -1; - return toReturn; -} -std::map NllLoss2dForwardOp::getStatistics() { - std::map toReturn; - // FIXME: unimplemented - toReturn["reads"] = -1; - toReturn["writes"] = -1; - return toReturn; -} -std::map NllLoss2dBackwardOp::getStatistics() { - std::map toReturn; - // FIXME: unimplemented - toReturn["reads"] = -1; - toReturn["writes"] = -1; - return toReturn; -} - // std::map ReLUUnderOp::getStatistics() { // return getReLUOpStatistics(*this); // } diff --git a/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp b/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp index 92295daf6..8696081b6 100644 --- a/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp +++ b/lib/Dialect/ATen/Transforms/RecognizeKernelsPass.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "npcomp/Dialect/ATen/IR/ATenDialect.h" #include "npcomp/Dialect/ATen/Transforms/Passes.h" +#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h" #include "npcomp/Dialect/Torch/IR/OpInterfaces.h" @@ -28,6 +29,14 @@ using namespace mlir::NPCOMP::Torch; namespace { +bool isTorchTensorType(StringRef torchType) { + return torchType == "Tensor" || torchType == "Tensor?"; +} + +bool isTorchOptionalType(StringRef torchType) { + return torchType.endswith("?"); +} + struct TypeConversion { Type targetType; std::function()) + return TypeConversion{sourceMlirType, nullptr}; + // Already immutable. 
if (sourceMlirType.isa()) return TypeConversion{sourceMlirType, nullptr}; @@ -86,30 +101,51 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType, Type sourceMlirType) { using KVC = KernelValueConversion::BitMask; // Default trivial case. - if (sourceTorchType == targetTorchType && flag == 0) + if (sourceTorchType == targetTorchType && flag == 0) { + LLVM_DEBUG(llvm::dbgs() << " * Return types already match\n"); return TypeConversion{sourceMlirType, nullptr}; + } // Immutable tensor conversion. if (flag & KVC::kImmutableTensor) { - if (sourceTorchType != "Tensor" || targetTorchType != "Tensor") + LLVM_DEBUG(llvm::dbgs() + << " * Return conversion flag kImmutableTensor\n"); + if (!isTorchTensorType(sourceTorchType) || + !isTorchTensorType(targetTorchType)) { + LLVM_DEBUG(llvm::dbgs() + << " * Source or target not a Tensor type\n"); return None; + } // Already immutable. - if (sourceMlirType.isa()) + if (sourceMlirType.isa()) { + LLVM_DEBUG(llvm::dbgs() << " * Source is already immutable\n"); return TypeConversion{sourceMlirType, nullptr}; + } // Convert NdArray type. - if (auto ndArrayType = sourceMlirType.dyn_cast()) { + if (sourceMlirType.isa() && + isTorchOptionalType(targetTorchType)) { + LLVM_DEBUG(llvm::dbgs() << " * None Tensor type passthrough\n"); + return TypeConversion{sourceMlirType, nullptr}; + } else if (auto ndArrayType = + sourceMlirType.dyn_cast()) { auto tensorType = ndArrayType.toTensorType(); auto callback = [=](Location loc, Value newOpResultValue, PatternRewriter &rewriter) -> Value { return rewriter.create( loc, ndArrayType, newOpResultValue); }; + LLVM_DEBUG(llvm::dbgs() << " * Convert return type\n"); return TypeConversion{tensorType, callback}; + } else { + LLVM_DEBUG(llvm::dbgs() + << " * Return type is not a supported tensor type\n"); + return None; } } + LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n"); return None; } @@ -142,9 +178,16 @@ public: const BuildKernelMetadata &buildMetadata) { LLVM_DEBUG(llvm::dbgs() << "Register kernel call translation for: " << opName << "\n"); - CandidateTransformList &candidates = - kernelTransforms[buildMetadata.kernelName]; - candidates.emplace_back(opName, buildMetadata); + { + CandidateTransformList &candidates = + kernelTransforms[buildMetadata.kernelName]; + candidates.emplace_back(opName, buildMetadata); + } + + for (StringRef aliasKernelName : buildMetadata.aliasKernelNames) { + CandidateTransformList &candidates = kernelTransforms[aliasKernelName]; + candidates.emplace_back(opName, buildMetadata); + } } LogicalResult transformKernelCall(KernelCallOp kernelCall, @@ -229,6 +272,8 @@ public: "arg arity mismatch"); // Convert fixed return types. + using PostConversionCallback = std::function; + SmallVector postConversionCallbacks; struct ConversionInfo { Value originalValue; TypeConversion conversion; @@ -241,25 +286,49 @@ public: KVC flag = candidate.buildMetadata.getReturnConversion(i); Value sourceValue = kernelCall.getResult(i); Type sourceMlirType = kernelCall.getResultTypes()[i]; - auto conversion = convertTorchReturnType(sourceTorchType, targetTorchType, - flag, sourceMlirType); - if (!conversion) { - LLVM_DEBUG(llvm::dbgs() << " - Return type[" << i - << "] incompatible: source=" << sourceTorchType - << ", target=" << targetTorchType - << ", flag=" << flag << "\n"); - return failure(); + if (flag & KVC::kDropReturnAndAliasArg0) { + // Reduce result arity and alias any uses to arg0. 
+        if (kernelCall.args().empty()) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << " - Cannot alias arg0 (no arguments)\n");
+          return failure();
+        }
+        Value arg0 = kernelCall.args()[0];
+        postConversionCallbacks.push_back(
+            [sourceValue, arg0]() { sourceValue.replaceAllUsesWith(arg0); });
+      } else {
+        // General, arity-preserving type conversion.
+        auto conversion = convertTorchReturnType(
+            sourceTorchType, targetTorchType, flag, sourceMlirType);
+        if (!conversion) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << " - Return type[" << i << "] incompatible: source="
+                     << sourceTorchType << ", target=" << targetTorchType
+                     << ", flag=" << flag << "\n");
+          return failure();
+        }
+        resultTypes.push_back(conversion->targetType);
+        resultConversions.push_back({sourceValue, std::move(*conversion)});
       }
-      resultTypes.push_back(conversion->targetType);
-      resultConversions.push_back({sourceValue, std::move(*conversion)});
     }
 
     // Convert fixed arg types.
     SmallVector<ConversionInfo, 4> operandInfos;
-    for (size_t i = 0; i < fixedArgArity; ++i) {
+    for (size_t i = 0, operandIndex = 0; i < fixedArgArity; ++i) {
+      // Drop this arg?
+      if (candidate.buildMetadata.argConversions[i] & KVC::kDrop)
+        continue;
+      if (kernelCall.getNumOperands() <= operandIndex) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << " - Arg operand " << i
+                   << " does not exist in kernel call (missing default?)\n");
+        return failure();
+      }
+
+      // Normal type conversion of the operand.
       operandInfos.emplace_back();
       ConversionInfo &info = operandInfos.back();
-      info.originalValue = kernelCall.getOperand(i);
+      info.originalValue = kernelCall.getOperand(operandIndex++);
       Type sourceMlirType = info.originalValue.getType();
       auto conversion = convertTorchArgType(
           /*sourceTorchType=*/sourceMetadata.argTypes[i],
@@ -312,6 +381,10 @@ public:
       origOpResultValue.replaceAllUsesWith(convertedValue);
     }
 
+    // Post conversion callbacks.
+    for (auto &callback : postConversionCallbacks)
+      callback();
+
     // Done.
rewriter.eraseOp(kernelCall); return success(); diff --git a/test/Dialect/ATen/aten_conv2d.mlir b/test/Dialect/ATen/aten_conv2d.mlir deleted file mode 100644 index 4f9f500b0..000000000 --- a/test/Dialect/ATen/aten_conv2d.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s -// CHECK-LABEL: "L0-_convolution-0": { -// CHECK-NEXT: "activation_in": 32768, -// CHECK-NEXT: "activation_out": 65536, -// CHECK-NEXT: "ops:+": 65536, -// CHECK-NEXT: "ops:MAC": 6422528, -// CHECK-NEXT: "parameters_in": 1584, -// CHECK-NEXT: "reads": 34352, -// CHECK-NEXT: "writes": 65536 - -module { - func @graph(%arg0: tensor<1x2x128x128xf32>, %arg1: tensor<16x2x7x7xf32>, %arg2: tensor<16xf32>) -> tensor<1x16x64x64xf32> { - %0 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list - %1 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list - %2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list - %3 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %4 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi64>} : () -> !aten.list - %5 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %6 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %7 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1 - %8 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1 - %9 = "aten._convolution"(%arg0, %arg1, %arg2, %0, %1, %2) : (tensor<1x2x128x128xf32>, tensor<16x2x7x7xf32>, tensor<16xf32>, !aten.list, !aten.list, !aten.list) -> tensor<1x16x64x64xf32> - "std.return"(%9) : (tensor<1x16x64x64xf32>) -> () - } -} diff --git a/test/Dialect/ATen/aten_conv2d_back.mlir b/test/Dialect/ATen/aten_conv2d_back.mlir deleted file mode 100644 index a25dace07..000000000 --- a/test/Dialect/ATen/aten_conv2d_back.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s -// CHECK-LABEL: "L0-convolution_backward_overrideable-0": { -// CHECK-NEXT: "activation_in": 5568, -// CHECK-NEXT: "grad": 5380, -// CHECK-NEXT: "ops:+": 768, -// CHECK-NEXT: "ops:MAC": 345600, -// CHECK-NEXT: "parameters_in": 576, -// CHECK-NEXT: "reads": 6144, -// CHECK-NEXT: "writes": 5380 -// CHECK-NEXT: } - -// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION -// CHECK-CONVERSION-LABEL: @graph -module { - func @graph(%arg0: tensor<3x4x8x8xf32>, %arg1: tensor<3x16x10x10xf32>, %arg2: tensor<4x16x3x3xf32>) -> tensor<4x16x3x3xf32> { - %0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %2 = "aten.constant"() {type = "bool", value = false} : () -> i1 - %3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %10:3 = "aten.convolution_backward_overrideable"(%arg0, %arg1, %arg2, %0, %1, %0, %2, %1, %3) {layer_name = "L5-convolution_backward_overrideable-0"} : (tensor<3x4x8x8xf32>, tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, !aten.list, !aten.list, !aten.list, i1, !aten.list, i32) -> (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>) - return %10#1 : tensor<4x16x3x3xf32> - } -} diff --git a/test/Dialect/ATen/lenet_fwd.mlir b/test/Dialect/ATen/lenet_fwd.mlir deleted file mode 100644 index 746a03fcf..000000000 --- a/test/Dialect/ATen/lenet_fwd.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: 
npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION -// CHECK-CONVERSION-LABEL: @graph - -module { - func @graph(%arg0: tensor<10xf32>, %arg1: tensor<128xf32>, %arg2: tensor<4x1x28x28xf32>, %arg3: tensor<32x1x3x3xf32>, %arg4: tensor<32xf32>, %arg5: tensor<64x32x3x3xf32>, %arg6: tensor<64xf32>, %arg7: tensor<128x9216xf32>, %arg8: tensor<10x128xf32>) -> tensor<4x10xf32> { - %0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list - %1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list - %2 = "aten.constant"() {type = "bool", value = false} : () -> i1 - %3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32 - %4 = "aten.constant"() {type = "bool", value = true} : () -> i1 - %5 = "aten._convolution"(%arg2, %arg3, %arg4, %0, %1, %0) {layer_name = "L0-_convolution-0"} : (tensor<4x1x28x28xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>, !aten.list, !aten.list, !aten.list) -> tensor<4x32x26x26xf32> - %6 = "aten.relu"(%5) {layer_name = "L1-relu-0"} : (tensor<4x32x26x26xf32>) -> tensor<4x32x26x26xf32> - %7 = "aten._convolution"(%6, %arg5, %arg6, %0, %1, %0) {layer_name = "L2-_convolution-1"} : (tensor<4x32x26x26xf32>, tensor<64x32x3x3xf32>, tensor<64xf32>, !aten.list, !aten.list, !aten.list) -> tensor<4x64x24x24xf32> - %8 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi32>} : () -> !aten.list - %9:2 = "aten.max_pool2d_with_indices"(%7, %8, %8, %1, %0, %2) {layer_name = "L3-max_pool2d_with_indices-0"} : (tensor<4x64x24x24xf32>, !aten.list, !aten.list, !aten.list, !aten.list, i1) -> (tensor<4x64x12x12xf32>, tensor<4x64x12x12xi64>) - %10 = "aten.constant"() {type = "List[i32]", value = dense<[4, 9216]> : vector<2xi32>} : () -> !aten.list - %11 = "aten.view"(%9#0, %10) {layer_name = "L4-view-0"} : (tensor<4x64x12x12xf32>, !aten.list) -> tensor<4x9216xf32> - %12 = "aten.t"(%arg7) {layer_name = "L5-t-0"} : (tensor<128x9216xf32>) -> tensor<9216x128xf32> - %13 = "aten.addmm"(%arg1, %11, %12, %3, %3) {layer_name = "L6-addmm-0"} : (tensor<128xf32>, tensor<4x9216xf32>, tensor<9216x128xf32>, i32, i32) -> tensor<4x128xf32> - %14 = "aten.relu"(%13) {layer_name = "L7-relu-1"} : (tensor<4x128xf32>) -> tensor<4x128xf32> - %15 = "aten.t"(%arg8) {layer_name = "L8-t-1"} : (tensor<10x128xf32>) -> tensor<128x10xf32> - %16 = "aten.addmm"(%arg0, %14, %15, %3, %3) {layer_name = "L9-addmm-1"} : (tensor<10xf32>, tensor<4x128xf32>, tensor<128x10xf32>, i32, i32) -> tensor<4x10xf32> - %17 = "aten._log_softmax"(%16, %3, %2) {layer_name = "L10-_log_softmax-0"} : (tensor<4x10xf32>, i32, i1) -> tensor<4x10xf32> - return %17 : tensor<4x10xf32> - } -} diff --git a/test/Dialect/ATen/recognize_aten_kernels.mlir b/test/Dialect/ATen/recognize_aten_kernels.mlir index 9cbb973ad..5963e2094 100644 --- a/test/Dialect/ATen/recognize_aten_kernels.mlir +++ b/test/Dialect/ATen/recognize_aten_kernels.mlir @@ -13,3 +13,86 @@ func @graph(%arg0: !numpy.ndarray<*:?>, %arg1 : !numpy.ndarray<*:?>, %arg2 : si6 // CHECK: return %[[RESULT_MUT]] return %0 : !numpy.ndarray<*:?> } + + +// ----- +// CHECK-LABEL: func @nll_loss2d_forward +// Contains a Tensor? type mapped to None. 
+func @nll_loss2d_forward(
+    %arg0: !numpy.ndarray<[3,4,8,8]:f32>,
+    %arg1: !numpy.ndarray<[3,8,8]:i64>,
+    %arg2: !basicpy.NoneType,
+    %arg3: i64,
+    %arg4: i64) -> (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>) {
+  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
+  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
+  // CHECK: %[[TOUTPUT:.*]], %[[TTOTAL_WEIGHT:.*]] = "aten.nll_loss2d_forward"(%[[TARG0]], %[[TARG1]], %arg2, %arg3, %arg4) : (tensor<3x4x8x8xf32>, tensor<3x8x8xi64>, !basicpy.NoneType, i64, i64) -> (tensor<f32>, tensor<f32>)
+  // CHECK: %[[AOUTPUT:.*]] = numpy.create_array_from_tensor %[[TOUTPUT]]
+  // CHECK: %[[ATOTAL_WEIGHT:.*]] = numpy.create_array_from_tensor %[[TTOTAL_WEIGHT]]
+  %0:2 = torch.kernel_call "aten::nll_loss2d_forward"
+      %arg0, %arg1, %arg2, %arg3, %arg4 :
+      (!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,8,8]:i64>, !basicpy.NoneType, i64, i64) ->
+      (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>)
+      {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor"]}
+  // CHECK: return %[[AOUTPUT]], %[[ATOTAL_WEIGHT]]
+  return %0#0, %0#1 : !numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>
+}
+
+// -----
+// CHECK-LABEL: func @convolution
+// Contains Tensor?, bool, int, and list types.
+func @convolution(
+    %arg0: !numpy.ndarray<[3,16,10,10]:f32>, %arg1: !numpy.ndarray<[4,16,3,3]:f32>,
+    %arg2: !numpy.ndarray<[4]:f32>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType,
+    %arg5: !basicpy.ListType, %arg6: i1, %arg7: !basicpy.ListType, %arg8: i64) -> !numpy.ndarray<[3,4,8,8]:f32> {
+  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
+  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
+  // CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
+  // CHECK: %[[TRESULT:.*]] = "aten.convolution"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> tensor<3x4x8x8xf32>
+  %0 = torch.kernel_call "aten::convolution"
+      %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8 :
+      (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>,
+       !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType,
+       !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
+      {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
+  return %0 : !numpy.ndarray<[3,4,8,8]:f32>
+}
+
+// -----
+// CHECK-LABEL: func @convolution_backward
+// Interesting because it has optional tensor returns.
+func @convolution_backward(
+    %arg0: !numpy.ndarray<[3,4,8,8]:f32>,
+    %arg1: !numpy.ndarray<[3,16,10,10]:f32>,
+    %arg2: !numpy.ndarray<[4,16,3,3]:f32>,
+    %arg3: !basicpy.ListType,
+    %arg4: !basicpy.ListType,
+    %arg5: !basicpy.ListType,
+    %arg6: i1,
+    %arg7: !basicpy.ListType,
+    %arg8: i64,
+    %arg9: !basicpy.ListType) -> (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {
+  // CHECK: %[[GRAD_INPUT:.*]], %[[GRAD_WEIGHT:.*]], %[[GRAD_BIAS:.*]] = "aten.convolution_backward"
+  // Note that this kernel call masks out the input gradients, which will return as NoneType.
+  %0:3 = torch.kernel_call "aten::convolution_backward"
+      %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 :
+      (!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64, !basicpy.ListType) ->
+      (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {sigArgTypes = ["Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor", "Tensor"]}
+  // CHECK: %[[AGRAD_WEIGHT:.*]] = numpy.create_array_from_tensor %[[GRAD_WEIGHT]]
+  // CHECK: %[[AGRAD_BIAS:.*]] = numpy.create_array_from_tensor %[[GRAD_BIAS]]
+  // Key point: the return yields the raw NoneType from the masked input gradient;
+  // it does not get converted to an array.
+  // CHECK: return %[[GRAD_INPUT]], %[[AGRAD_WEIGHT]], %[[AGRAD_BIAS]]
+  return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
+}
+
+// -----
+// CHECK-LABEL: func @copy_inplace
+// Mutable/in-place op conversion, dropping the result.
+func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {
+  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
+  // CHECK: "aten.copy.inplace"(%arg0, %[[TARG1]]) : (!numpy.ndarray<[4]:f32>, tensor<4xf32>) -> ()
+  %0 = torch.kernel_call "aten::copy_" %arg0, %arg1 : (!numpy.ndarray<[4]:f32>, !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {sigArgTypes = ["Tensor", "Tensor", "bool"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
+  // CHECK: return %arg0
+  return %0 : !numpy.ndarray<[4]:f32>
+}