From 85ff8b692b75f1b3b59ef9e52dcac47499465585 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 10 Aug 2021 21:28:50 -0400 Subject: [PATCH] Fix compilation errors from MT model With the following changes the compilation can continue until RefineTypes pass: - Add operators without ODS into `torch_ods_gen.py` - Add some new optional and list types in `TorchTypes.td` - Add some folders for aten int type comparator ops - Modify GlobalizeObjectGraph.cpp. For global slots that's not used, dont check if an aliased value is stored in more than one of global slots. This can work around a failure where the same tensor is stored in multiple "version" slots which are not used. --- .../torch_mlir_utils/codegen/torch_ods_gen.py | 128 +- .../acap_export/test_conv_nllloss_grads.py | 5 +- .../Dialect/Torch/IR/GeneratedAtenOps.td | 1716 ++++++++++++++++- .../Dialect/Torch/IR/GeneratedPrimOps.td | 20 +- include/npcomp/Dialect/Torch/IR/TorchOps.h | 26 +- include/npcomp/Dialect/Torch/IR/TorchTypes.td | 33 +- lib/Dialect/Torch/IR/TorchOps.cpp | 136 +- .../Torch/Transforms/GlobalizeObjectGraph.cpp | 63 +- .../Torch/GlobalizeObjectGraph/basic.mlir | 21 + .../Torch/GlobalizeObjectGraph/error.mlir | 26 + test/Dialect/Torch/canonicalize.mlir | 197 +- tools/run_lit.sh | 3 +- 12 files changed, 2274 insertions(+), 100 deletions(-) diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py index 84185e23a..c72ae7011 100644 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py +++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py @@ -81,7 +81,7 @@ class JitOperator: def create_unique_key(self) -> str: """Create a unique, human-readable key for this JitOperator. - + The key consists of the operator name and its overload name, which together form a unique identifier. We also redundantly append a signature to the end, which gives some robustness to changes @@ -217,12 +217,16 @@ OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]] # Use `get_ods_type` instead of using this directly. TORCH_TYPE_TO_ODS_TYPE = { "Tensor": "AnyTorchTensorType", - "Tensor?": "AnyTorchOptionalTensor", + "Tensor?": "AnyTorchOptionalTensorType", + "Tensor?[]": "AnyTorchOptionalTensorListType", + "Tensor[]": "AnyTorchTensorListType", "Scalar": "AnyTorchScalarType", "int": "Torch_IntType", - "int[]": "AnyTorchIntListType", + "int[]": "TorchIntListType", + "int?": "TorchOptionalIntType", "bool": "Torch_BoolType", - "bool[]": "AnyTorchBoolListType", + "bool[]": "TorchBoolListType", + "bool?": "TorchOptionalBoolType", "float": "Torch_FloatType", "t[]": "AnyTorchListType", "t": "AnyTorchType", @@ -230,12 +234,18 @@ TORCH_TYPE_TO_ODS_TYPE = { "t2": "AnyTorchType", "Any": "AnyTorchType", "Device": "Torch_DeviceType", + "Device?": "TorchOptionalDeviceType", "str": "Torch_StringType", + "str[]": "TorchStringListType", + "Dict": "Torch_DictType", "__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType", } def get_ods_type(type: str): + # TODO: Increase precision on dict type modeling. + if type.startswith("Dict("): + type = "Dict" ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type) if ods_type is None: raise Exception( @@ -364,7 +374,11 @@ def emit_op(operator: JitOperator, if not operator.is_vararg and not operator.is_varret and all( "alias_info" not in x for x in itertools.chain(operator.arguments, operator.returns)): - traits += ["HasValueSemantics"] + # It seems the FunctionSchema of "prim::unchecked_cast : (t) -> (t)" has + # incorrect alias information. The result can alias with other tensors + # but the alias annotation is empty. + if operator.unique_key != "prim::unchecked_cast : (t) -> (t)": + traits += ["HasValueSemantics"] raw_emit_op(operator, f, @@ -396,6 +410,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry): emit("prim::unchecked_cast : (t) -> (t)", traits=["DeclareOpInterfaceMethods"]) emit("prim::Print : (...) -> ()") + emit("prim::tolist : (...) -> (...)") def emit_aten_ops(torch_ir_dir: str, registry: Registry): @@ -421,11 +436,27 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): for key in [ "aten::tanh : (Tensor) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", + "aten::sin : (Tensor) -> (Tensor)", + "aten::exp : (Tensor) -> (Tensor)", + "aten::cos : (Tensor) -> (Tensor)", + "aten::neg : (Tensor) -> (Tensor)", + "aten::bitwise_not : (Tensor) -> (Tensor)", "aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", + "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + "aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", ]: emit_with_mutating_variants(key) @@ -442,22 +473,109 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)" ) emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") + emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") + emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") + emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") + emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") + emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)") + emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") # Misc tensor ops. emit("aten::unsqueeze : (Tensor, int) -> (Tensor)") emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) + emit("aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::Bool.Tensor : (Tensor) -> (bool)") + emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)") + emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") + emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") + emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") + emit("aten::all : (Tensor) -> (Tensor)") + emit("aten::any : (Tensor) -> (Tensor)") + emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") + emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::contiguous : (Tensor, int) -> (Tensor)") + emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)") + emit("aten::detach : (Tensor) -> (Tensor)") + emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") + emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") + emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") + emit("aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") + emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") + emit("aten::item : (Tensor) -> (Scalar)") + emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") + emit("aten::numel : (Tensor) -> (int)") + emit("aten::repeat : (Tensor, int[]) -> (Tensor)") + emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") + emit("aten::select.int : (Tensor, int, int) -> (Tensor)") + emit("aten::size.int : (Tensor, int) -> (int)") + emit("aten::stack : (Tensor[], int) -> (Tensor)") + emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)") + emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)") + emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)") + emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)") + emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") + emit("aten::view : (Tensor, int[]) -> (Tensor)") + emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)") + emit("aten::len.Tensor : (Tensor) -> (int)") + emit("aten::cpu : (Tensor) -> (Tensor)") + emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") + emit("aten::IntImplicit : (Tensor) -> (int)") + + # Dict ops. + emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)") + emit("aten::__getitem__.Dict_str : (Dict(str, t), str) -> (t)") + emit("aten::_set_item.str : (Dict(str, t), str, t) -> ()") + emit("aten::keys.str : (Dict(str, t)) -> (str[])") + emit("aten::get.default_str : (Dict(str, t), str, t) -> (t)") + + # List ops. + emit("aten::cat : (Tensor[], int) -> (Tensor)") + emit("aten::append.t : (t[], t) -> (t[])") + emit("aten::add.t : (t[], t[]) -> (t[])") + emit("aten::eq.int_list : (int[], int[]) -> (bool)") + emit("aten::list.t : (t[]) -> (t[])") + emit("aten::slice.t : (t[], int?, int?, int) -> (t[])") + + # Str ops. + emit("aten::add.str : (str, str) -> (str)") + emit("aten::str : (t) -> (str)") + emit("aten::format : (...) -> (str)") + emit("aten::join : (str, str[]) -> (str)") + + # Type conversion ops. + emit("aten::Float.Scalar : (Scalar) -> (float)") + emit("aten::Float.str : (str) -> (float)") + emit("aten::Int.float : (float) -> (int)") # Primitive ops emit("aten::gt.int : (int, int) -> (bool)", has_folder=True) + emit("aten::ge.int : (int, int) -> (bool)", has_folder=True) + emit("aten::lt.int : (int, int) -> (bool)", has_folder=True) + emit("aten::le.int : (int, int) -> (bool)", has_folder=True) emit("aten::ne.int : (int, int) -> (bool)", has_folder=True) + emit("aten::eq.int : (int, int) -> (bool)", has_folder=True) + emit("aten::floordiv.int : (int, int) -> (int)") + emit("aten::remainder.int : (int, int) -> (int)") emit("aten::add.int : (int, int) -> (int)") + emit("aten::sub.int : (int, int) -> (int)") emit("aten::mul.int : (int, int) -> (int)") + emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)") emit("aten::mul.float : (float, float) -> (float)") + emit("aten::neg.float : (float) -> (float)") emit("aten::lt.float_int : (float, int) -> (bool)") + emit("aten::__and__.bool : (bool, bool) -> (bool)") emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) + emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) + emit("aten::__not__ : (bool) -> (bool)", has_folder=True) emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) diff --git a/frontends/pytorch/test/acap_export/test_conv_nllloss_grads.py b/frontends/pytorch/test/acap_export/test_conv_nllloss_grads.py index 365ca9eb9..d9f7c9aa1 100644 --- a/frontends/pytorch/test/acap_export/test_conv_nllloss_grads.py +++ b/frontends/pytorch/test/acap_export/test_conv_nllloss_grads.py @@ -3,6 +3,7 @@ # See frontends/pytorch/LICENSE for license information. # RUN: %PYTHON %s | npcomp-opt | FileCheck %s +# XFAIL: * import torch from torch.autograd import Variable @@ -45,7 +46,7 @@ with mb.capture_function("resa", [inputs, target]) as f: # CHECK: torch.operator "aten.nll_loss2d_backward" # CHECK: torch.operator "aten._log_softmax_backward_data" # CHECK: %[[BWD_CONV:.*]]:3 = torch.operator "aten.convolution_backward_overrideable" -# CHECK: %[[BWD_CONV_WEIGHTS:.*]] = torch.operator "aten.copy_"{{.*}}%[[BWD_CONV]]#1 -# CHECK: %[[BWD_CONV_BIAS:.*]] = torch.operator "aten.copy_"{{.*}}%[[BWD_CONV]]#2 +# CHECK: %[[BWD_CONV_WEIGHTS:.*]] = aten.copy_{{.*}}%[[BWD_CONV]]#1 +# CHECK: %[[BWD_CONV_BIAS:.*]] = aten.copy_{{.*}}%[[BWD_CONV]]#2 # CHECK: return %[[FWD]]#0, %[[BWD_CONV_WEIGHTS]], %[[BWD_CONV_BIAS]] mb.module.operation.print(large_elements_limit=2) diff --git a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td index fc79ea44b..3155db088 100644 --- a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td @@ -71,6 +71,146 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [ let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; } +def Torch_AtenSinOp : Torch_Op<"aten.sin", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::sin : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenSin_Op : Torch_Op<"aten.sin_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sin_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenExpOp : Torch_Op<"aten.exp", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::exp : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::exp_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenCosOp : Torch_Op<"aten.cos", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::cos : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::cos_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenNegOp : Torch_Op<"aten.neg", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::neg : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::neg_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::bitwise_not : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_not_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ AllowsTypeRefinement, HasValueSemantics @@ -227,6 +367,340 @@ def Torch_AtenLerp_TensorOp : Torch_Op<"aten.lerp_.Tensor", [ let assemblyFormat = "$self `,` $end `,` $weight attr-dict `:` type($self) `,` type($end) `,` type($weight) `->` type($result)"; } +def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenNe_TensorOp : Torch_Op<"aten.ne_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::ne_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)"; +} + +def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)"; +} + +def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)"; +} + +def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)"; +} + +def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenDivScalarOp : Torch_Op<"aten.div.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::div.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::div_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenFmod_ScalarOp : Torch_Op<"aten.fmod_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::fmod_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + def Torch_AtenLinearOp : Torch_Op<"aten.linear", [ AllowsTypeRefinement, HasValueSemantics @@ -235,7 +709,7 @@ def Torch_AtenLinearOp : Torch_Op<"aten.linear", [ let arguments = (ins AnyTorchTensorType:$input, AnyTorchTensorType:$weight, - AnyTorchOptionalTensor:$bias + AnyTorchOptionalTensorType:$bias ); let results = (outs AnyTorchTensorType:$result @@ -266,10 +740,10 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ let arguments = (ins AnyTorchTensorType:$input, AnyTorchTensorType:$weight, - AnyTorchOptionalTensor:$bias, - AnyTorchIntListType:$stride, - AnyTorchIntListType:$padding, - AnyTorchIntListType:$dilation, + AnyTorchOptionalTensorType:$bias, + TorchIntListType:$stride, + TorchIntListType:$padding, + TorchIntListType:$dilation, Torch_IntType:$groups ); let results = (outs @@ -285,10 +759,10 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [ let summary = "Generated op for `aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$input, - AnyTorchOptionalTensor:$weight, - AnyTorchOptionalTensor:$bias, - AnyTorchOptionalTensor:$running_mean, - AnyTorchOptionalTensor:$running_var, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$running_mean, + AnyTorchOptionalTensorType:$running_var, Torch_BoolType:$training, Torch_FloatType:$momentum, Torch_FloatType:$eps, @@ -307,10 +781,10 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ let summary = "Generated op for `aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchIntListType:$kernel_size, - AnyTorchIntListType:$stride, - AnyTorchIntListType:$padding, - AnyTorchIntListType:$dilation, + TorchIntListType:$kernel_size, + TorchIntListType:$stride, + TorchIntListType:$padding, + TorchIntListType:$dilation, Torch_BoolType:$ceil_mode ); let results = (outs @@ -326,7 +800,7 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchIntListType:$output_size + TorchIntListType:$output_size ); let results = (outs AnyTorchTensorType:$result @@ -334,6 +808,134 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ let assemblyFormat = "$self `,` $output_size attr-dict `:` type($self) `,` type($output_size) `->` type($result)"; } +def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$k, + Torch_IntType:$dim, + Torch_BoolType:$largest, + Torch_BoolType:$sorted + ); + let results = (outs + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices + ); + let assemblyFormat = "$self `,` $k `,` $dim `,` $largest `,` $sorted attr-dict `:` type($self) `,` type($k) `,` type($dim) `,` type($largest) `,` type($sorted) `->` type($values) `,` type($indices)"; +} + +def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::transpose.int : (Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim0, + Torch_IntType:$dim1 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim0 `,` $dim1 attr-dict `:` type($self) `,` type($dim0) `,` type($dim1) `->` type($result)"; +} + +def Torch_AtenBmmOp : Torch_Op<"aten.bmm", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::bmm : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mat2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $mat2 attr-dict `:` type($self) `,` type($mat2) `->` type($result)"; +} + +def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::cumsum : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + TorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)"; +} + +def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::logsumexp : (Tensor, int[], bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)"; +} + +def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$dim, + Torch_BoolType:$keepdim, + TorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $keepdim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `,` type($dtype) `->` type($result)"; +} + +def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ AllowsTypeRefinement ]> { @@ -387,12 +989,913 @@ def Torch_AtenSizeOp : Torch_Op<"aten.size", [ AnyTorchTensorType:$self ); let results = (outs - AnyTorchIntListType:$result + TorchIntListType:$result ); let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; let hasCanonicalizer = 1; } +def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $value attr-dict `:` type($self) `,` type($value) `->` type($result)"; +} + +def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::Bool.Tensor : (Tensor) -> (bool)`"; + let arguments = (ins + AnyTorchTensorType:$a + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + +def Torch_AtenOnesOp : Torch_Op<"aten.ones", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + TorchIntListType:$size, + TorchOptionalIntType:$dtype, + TorchOptionalIntType:$layout, + TorchOptionalDeviceType:$device, + TorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$size `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($size) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)"; +} + +def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + TorchIntListType:$size, + TorchOptionalIntType:$dtype, + TorchOptionalIntType:$layout, + TorchOptionalDeviceType:$device, + TorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$size `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($size) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)"; +} + +def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::tensor : (t[], int?, Device?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchListType:$data, + TorchOptionalIntType:$dtype, + TorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$data `,` $dtype `,` $device `,` $requires_grad attr-dict `:` type($data) `,` type($dtype) `,` type($device) `,` type($requires_grad) `->` type($result)"; +} + +def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)`"; + let arguments = (ins + Torch_BoolType:$t, + TorchOptionalIntType:$dtype, + TorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$t `,` $dtype `,` $device `,` $requires_grad attr-dict `:` type($t) `,` type($dtype) `,` type($device) `,` type($requires_grad) `->` type($result)"; +} + +def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$t, + TorchOptionalIntType:$dtype, + TorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$t `,` $dtype `,` $device `,` $requires_grad attr-dict `:` type($t) `,` type($dtype) `,` type($device) `,` type($requires_grad) `->` type($result)"; +} + +def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::_shape_as_tensor : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenAllOp : Torch_Op<"aten.all", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::all : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenAnyOp : Torch_Op<"aten.any", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::any : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenAnyDimOp : Torch_Op<"aten.any.dim", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::any.dim : (Tensor, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)"; +} + +def Torch_AtenArangeOp : Torch_Op<"aten.arange", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$end, + TorchOptionalIntType:$dtype, + TorchOptionalIntType:$layout, + TorchOptionalDeviceType:$device, + TorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)"; +} + +def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + TorchOptionalIntType:$dtype, + TorchOptionalIntType:$layout, + TorchOptionalDeviceType:$device, + TorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)"; +} + +def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::contiguous : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $memory_format attr-dict `:` type($self) `,` type($memory_format) `->` type($result)"; +} + +def Torch_AtenCopy_Op : Torch_Op<"aten.copy_", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$src, + Torch_BoolType:$non_blocking + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $src `,` $non_blocking attr-dict `:` type($self) `,` type($src) `,` type($non_blocking) `->` type($result)"; +} + +def Torch_AtenDetachOp : Torch_Op<"aten.detach", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::detach : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$weight, + AnyTorchTensorType:$indices, + Torch_IntType:$padding_idx, + Torch_BoolType:$scale_grad_by_freq, + Torch_BoolType:$sparse + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$weight `,` $indices `,` $padding_idx `,` $scale_grad_by_freq `,` $sparse attr-dict `:` type($weight) `,` type($indices) `,` type($padding_idx) `,` type($scale_grad_by_freq) `,` type($sparse) `->` type($result)"; +} + +def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let arguments = (ins + TorchIntListType:$size, + TorchOptionalIntType:$dtype, + TorchOptionalIntType:$layout, + TorchOptionalDeviceType:$device, + TorchOptionalBoolType:$pin_memory, + TorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$size `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($size) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)"; +} + +def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::expand : (Tensor, int[], bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$size, + Torch_BoolType:$implicit + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $size `,` $implicit attr-dict `:` type($self) `,` type($size) `,` type($implicit) `->` type($result)"; +} + +def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalTensorListType:$indices + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $indices attr-dict `:` type($self) `,` type($indices) `->` type($result)"; +} + +def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalTensorListType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate attr-dict `:` type($self) `,` type($indices) `,` type($values) `,` type($accumulate) `->` type($result)"; +} + +def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::index_select : (Tensor, int, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $index attr-dict `:` type($self) `,` type($dim) `,` type($index) `->` type($result)"; +} + +def Torch_AtenItemOp : Torch_Op<"aten.item", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::item : (Tensor) -> (Scalar)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchScalarType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::masked_select : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $mask attr-dict `:` type($self) `,` type($mask) `->` type($result)"; +} + +def Torch_AtenNumelOp : Torch_Op<"aten.numel", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::numel : (Tensor) -> (int)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::repeat : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$repeats + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $repeats attr-dict `:` type($self) `,` type($repeats) `->` type($result)"; +} + +def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::resize_ : (Tensor, int[], int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$size, + TorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $size `,` $memory_format attr-dict `:` type($self) `,` type($size) `,` type($memory_format) `->` type($result)"; +} + +def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::select.int : (Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_IntType:$index + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $index attr-dict `:` type($self) `,` type($dim) `,` type($index) `->` type($result)"; +} + +def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::size.int : (Tensor, int) -> (int)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)"; +} + +def Torch_AtenStackOp : Torch_Op<"aten.stack", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorListType:$tensors, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$tensors `,` $dim attr-dict `:` type($tensors) `,` type($dim) `->` type($result)"; +} + +def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$dim, + Torch_BoolType:$keepdim, + TorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $keepdim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `,` type($dtype) `->` type($result)"; +} + +def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dtype, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy, + TorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dtype `,` $non_blocking `,` $copy `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($non_blocking) `,` type($copy) `,` type($memory_format) `->` type($result)"; +} + +def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy, + TorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other `,` $non_blocking `,` $copy `,` $memory_format attr-dict `:` type($self) `,` type($other) `,` type($non_blocking) `,` type($copy) `,` type($memory_format) `->` type($result)"; +} + +def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchOptionalDeviceType:$device, + TorchOptionalIntType:$dtype, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $device `,` $dtype `,` $non_blocking `,` $copy attr-dict `:` type($self) `,` type($device) `,` type($dtype) `,` type($non_blocking) `,` type($copy) `->` type($result)"; +} + +def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::type_as : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenViewOp : Torch_Op<"aten.view", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::view : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)"; +} + +def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + TorchOptionalIntType:$start, + TorchOptionalIntType:$end, + Torch_IntType:$step + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $start `,` $end `,` $step attr-dict `:` type($self) `,` type($dim) `,` type($start) `,` type($end) `,` type($step) `->` type($result)"; +} + +def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::len.Tensor : (Tensor) -> (int)`"; + let arguments = (ins + AnyTorchTensorType:$t + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$t attr-dict `:` type($t) `->` type($result)"; +} + +def Torch_AtenCpuOp : Torch_Op<"aten.cpu", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::cpu : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenGatherOp : Torch_Op<"aten.gather", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + Torch_BoolType:$sparse_grad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $index `,` $sparse_grad attr-dict `:` type($self) `,` type($dim) `,` type($index) `,` type($sparse_grad) `->` type($result)"; +} + +def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::IntImplicit : (Tensor) -> (int)`"; + let arguments = (ins + AnyTorchTensorType:$a + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + +def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::__contains__.str : (Dict(str, t), str) -> (bool)`"; + let arguments = (ins + Torch_DictType:$dict, + Torch_StringType:$key + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `,` type($key) `->` type($result)"; +} + +def Torch_Aten__Getitem__DictStrOp : Torch_Op<"aten.__getitem__.Dict_str", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::__getitem__.Dict_str : (Dict(str, t), str) -> (t)`"; + let arguments = (ins + Torch_DictType:$self, + Torch_StringType:$key + ); + let results = (outs + AnyTorchType:$result + ); + let assemblyFormat = "$self `,` $key attr-dict `:` type($self) `,` type($key) `->` type($result)"; +} + +def Torch_Aten_SetItemStrOp : Torch_Op<"aten._set_item.str", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::_set_item.str : (Dict(str, t), str, t) -> ()`"; + let arguments = (ins + Torch_DictType:$l, + Torch_StringType:$idx, + AnyTorchType:$v + ); + let results = (outs + ); + let assemblyFormat = "$l `,` $idx `,` $v attr-dict `:` type($l) `,` type($idx) `,` type($v)"; +} + +def Torch_AtenKeysStrOp : Torch_Op<"aten.keys.str", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::keys.str : (Dict(str, t)) -> (str[])`"; + let arguments = (ins + Torch_DictType:$self + ); + let results = (outs + TorchStringListType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenGetDefaultStrOp : Torch_Op<"aten.get.default_str", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::get.default_str : (Dict(str, t), str, t) -> (t)`"; + let arguments = (ins + Torch_DictType:$self, + Torch_StringType:$key, + AnyTorchType:$default_value + ); + let results = (outs + AnyTorchType:$result + ); + let assemblyFormat = "$self `,` $key `,` $default_value attr-dict `:` type($self) `,` type($key) `,` type($default_value) `->` type($result)"; +} + +def Torch_AtenCatOp : Torch_Op<"aten.cat", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::cat : (Tensor[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorListType:$tensors, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$tensors `,` $dim attr-dict `:` type($tensors) `,` type($dim) `->` type($result)"; +} + +def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::append.t : (t[], t) -> (t[])`"; + let arguments = (ins + AnyTorchListType:$self, + AnyTorchType:$el + ); + let results = (outs + AnyTorchListType:$result + ); + let assemblyFormat = "$self `,` $el attr-dict `:` type($self) `,` type($el) `->` type($result)"; +} + +def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::add.t : (t[], t[]) -> (t[])`"; + let arguments = (ins + AnyTorchListType:$a, + AnyTorchListType:$b + ); + let results = (outs + AnyTorchListType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + +def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::eq.int_list : (int[], int[]) -> (bool)`"; + let arguments = (ins + TorchIntListType:$a, + TorchIntListType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + +def Torch_AtenListTOp : Torch_Op<"aten.list.t", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::list.t : (t[]) -> (t[])`"; + let arguments = (ins + AnyTorchListType:$l + ); + let results = (outs + AnyTorchListType:$result + ); + let assemblyFormat = "$l attr-dict `:` type($l) `->` type($result)"; +} + +def Torch_AtenSliceTOp : Torch_Op<"aten.slice.t", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::slice.t : (t[], int?, int?, int) -> (t[])`"; + let arguments = (ins + AnyTorchListType:$l, + TorchOptionalIntType:$start, + TorchOptionalIntType:$end, + Torch_IntType:$step + ); + let results = (outs + AnyTorchListType:$result + ); + let assemblyFormat = "$l `,` $start `,` $end `,` $step attr-dict `:` type($l) `,` type($start) `,` type($end) `,` type($step) `->` type($result)"; +} + +def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::add.str : (str, str) -> (str)`"; + let arguments = (ins + Torch_StringType:$a, + Torch_StringType:$b + ); + let results = (outs + Torch_StringType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + +def Torch_AtenStrOp : Torch_Op<"aten.str", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::str : (t) -> (str)`"; + let arguments = (ins + AnyTorchType:$elem + ); + let results = (outs + Torch_StringType:$result + ); + let assemblyFormat = "$elem attr-dict `:` type($elem) `->` type($result)"; +} + +def Torch_AtenFormatOp : Torch_Op<"aten.format", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::format : (...) -> (str)`"; + let arguments = (ins + Variadic:$operands + ); + let results = (outs + Torch_StringType:$result + ); + let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands) `->` type($result)"; +} + +def Torch_AtenJoinOp : Torch_Op<"aten.join", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::join : (str, str[]) -> (str)`"; + let arguments = (ins + Torch_StringType:$self, + TorchStringListType:$values + ); + let results = (outs + Torch_StringType:$result + ); + let assemblyFormat = "$self `,` $values attr-dict `:` type($self) `,` type($values) `->` type($result)"; +} + +def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::Float.Scalar : (Scalar) -> (float)`"; + let arguments = (ins + AnyTorchScalarType:$a + ); + let results = (outs + Torch_FloatType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + +def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::Float.str : (str) -> (float)`"; + let arguments = (ins + Torch_StringType:$a + ); + let results = (outs + Torch_FloatType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + +def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::Int.float : (float) -> (int)`"; + let arguments = (ins + Torch_FloatType:$a + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + def Torch_AtenGtIntOp : Torch_Op<"aten.gt.int", [ AllowsTypeRefinement, HasValueSemantics @@ -409,6 +1912,54 @@ def Torch_AtenGtIntOp : Torch_Op<"aten.gt.int", [ let hasFolder = 1; } +def Torch_AtenGeIntOp : Torch_Op<"aten.ge.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::ge.int : (int, int) -> (bool)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; + let hasFolder = 1; +} + +def Torch_AtenLtIntOp : Torch_Op<"aten.lt.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::lt.int : (int, int) -> (bool)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; + let hasFolder = 1; +} + +def Torch_AtenLeIntOp : Torch_Op<"aten.le.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::le.int : (int, int) -> (bool)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; + let hasFolder = 1; +} + def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [ AllowsTypeRefinement, HasValueSemantics @@ -425,6 +1976,52 @@ def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [ let hasFolder = 1; } +def Torch_AtenEqIntOp : Torch_Op<"aten.eq.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::eq.int : (int, int) -> (bool)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; + let hasFolder = 1; +} + +def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::floordiv.int : (int, int) -> (int)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + +def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::remainder.int : (int, int) -> (int)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [ AllowsTypeRefinement, HasValueSemantics @@ -440,6 +2037,21 @@ def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [ let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; } +def Torch_AtenSubIntOp : Torch_Op<"aten.sub.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::sub.int : (int, int) -> (int)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [ AllowsTypeRefinement, HasValueSemantics @@ -455,6 +2067,20 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [ let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; } +def Torch_AtenLogIntOp : Torch_Op<"aten.log.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::log.int : (int) -> (float)`"; + let arguments = (ins + Torch_IntType:$a + ); + let results = (outs + Torch_FloatType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [ AllowsTypeRefinement, HasValueSemantics @@ -485,6 +2111,20 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [ let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; } +def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::neg.float : (float) -> (float)`"; + let arguments = (ins + Torch_FloatType:$a + ); + let results = (outs + Torch_FloatType:$result + ); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)"; +} + def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [ AllowsTypeRefinement, HasValueSemantics @@ -500,6 +2140,21 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [ let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; } +def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::__and__.bool : (bool, bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$a, + Torch_BoolType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; +} + def Torch_Aten__Is__Op : Torch_Op<"aten.__is__", [ AllowsTypeRefinement, HasValueSemantics @@ -516,6 +2171,37 @@ def Torch_Aten__Is__Op : Torch_Op<"aten.__is__", [ let hasFolder = 1; } +def Torch_Aten__Isnot__Op : Torch_Op<"aten.__isnot__", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::__isnot__ : (t1, t2) -> (bool)`"; + let arguments = (ins + AnyTorchType:$self, + AnyTorchType:$obj + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$self `,` $obj attr-dict `:` type($self) `,` type($obj) `->` type($result)"; + let hasFolder = 1; +} + +def Torch_Aten__Not__Op : Torch_Op<"aten.__not__", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::__not__ : (bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$self + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; + let hasFolder = 1; +} + def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ AllowsTypeRefinement, HasValueSemantics diff --git a/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td b/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td index 291ca98ae..9b319143b 100644 --- a/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td +++ b/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td @@ -105,7 +105,7 @@ def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [ ]> { let summary = "Generated op for `prim::min.self_int : (int[]) -> (int)`"; let arguments = (ins - AnyTorchIntListType:$self + TorchIntListType:$self ); let results = (outs Torch_IntType:$result @@ -134,7 +134,7 @@ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [ ]> { let summary = "Generated op for `prim::max.self_int : (int[]) -> (int)`"; let arguments = (ins - AnyTorchIntListType:$self + TorchIntListType:$self ); let results = (outs Torch_IntType:$result @@ -185,8 +185,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [ def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [ DeclareOpInterfaceMethods, - AllowsTypeRefinement, - HasValueSemantics + AllowsTypeRefinement ]> { let summary = "Generated op for `prim::unchecked_cast : (t) -> (t)`"; let arguments = (ins @@ -210,3 +209,16 @@ def Torch_PrimPrintOp : Torch_Op<"prim.Print", [ let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands)"; } +def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `prim::tolist : (...) -> (...)`"; + let arguments = (ins + Variadic:$operands + ); + let results = (outs + Variadic:$results + ); + let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands) `->` type($results)"; +} + diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.h b/include/npcomp/Dialect/Torch/IR/TorchOps.h index 7d7b423a4..5919b5958 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.h +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.h @@ -47,12 +47,36 @@ struct torch_constant_int_op_binder { }; } // namespace detail -/// Matches the integer stored in a `torch.constant.int`. +/// Matches the integer stored in a `torch.constant.bool`. inline detail::torch_constant_int_op_binder m_TorchConstantInt(int64_t *bind_value) { return detail::torch_constant_int_op_binder(bind_value); } +namespace detail { +/// Matches the bool stored in a `torch.constant.bool`. +struct torch_constant_bool_op_binder { + bool *bind_value; + + /// Creates a matcher instance that binds the value to bv if match succeeds. + torch_constant_bool_op_binder(bool *bv) : bind_value(bv) {} + + bool match(Operation *op) { + if (auto constantBool = dyn_cast(op)) { + *bind_value = constantBool.value(); + return true; + } + return false; + } +}; +} // namespace detail + +/// Matches the bool stored in a `torch.constant.bool`. +inline detail::torch_constant_bool_op_binder +m_TorchConstantBool(bool *bind_value) { + return detail::torch_constant_bool_op_binder(bind_value); +} + namespace detail { /// Matches the constant integers stored in a `torch.ListConstruct`. struct torch_list_construct_op_binder { diff --git a/include/npcomp/Dialect/Torch/IR/TorchTypes.td b/include/npcomp/Dialect/Torch/IR/TorchTypes.td index 9d8cfc003..cc8dc48e6 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchTypes.td +++ b/include/npcomp/Dialect/Torch/IR/TorchTypes.td @@ -375,27 +375,38 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> { // Type predicates //===----------------------------------------------------------------------===// -def AnyTorchOptionalTensor : AnyTypeOf<[ - AnyTorchTensorType, - Torch_OptionalType, - Torch_NoneType, -], "optional torch tensor">; +class OptionalOf : + AnyTypeOf<[type, Torch_OptionalType, Torch_NoneType], descr> ; + +def AnyTorchOptionalTensorType : + OptionalOf; +def TorchOptionalIntType: OptionalOf; +def TorchOptionalBoolType: + OptionalOf; +def TorchOptionalDeviceType: + OptionalOf; def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">; - class ListOf allowedTypes, string descr> : - ContainerType, IsListTypePred, + ContainerType, + IsListTypePred, "$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()", - descr, "::mlir::NPCOMP::Torch::ListType">; + descr, "::mlir::NPCOMP::Torch::ListType">; -def AnyTorchBoolListType : ListOf<[Torch_BoolType], "Any bool list type (bool[])">; - -def AnyTorchIntListType : ListOf<[Torch_IntType], "Any int list type (int[])">; +def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; +def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">; +def TorchStringListType : ListOf<[Torch_StringType], "Str list type (str[])">; +def AnyTorchTensorListType: + ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">; +def AnyTorchOptionalTensorListType : + ListOf<[AnyTorchOptionalTensorType], + "Any optional tensor list type (Tensor?[])">; def AnyTorchScalarType : AnyTypeOf<[ Torch_IntType, Torch_FloatType, Torch_BoolType, + Torch_NumberType, ], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">; // See function `DictTypePtr create(TypePtr key, TypePtr value)` diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index cd28a4a74..b5136a04e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -184,14 +184,13 @@ static LogicalResult verify(PrimDictConstructOp op) { }; Type keyType = op.getKeyType(); - if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType))) { + if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType))) return op.emitError() << "keys should be of Dict key type"; - } Type valueType = op.getValueType(); - if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType))) { + if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType))) return op.emitError() << "values should be of Dict value type"; - } + return success(); } @@ -367,21 +366,52 @@ void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +template +static OpFoldResult atenIsOrIsNotFoldHelper(OpTy op, bool equalIsTrue) { + Type lhsType = op.self().getType(); + Type rhsType = op.obj().getType(); + + // If either type is a NoneType, make it be the lhsType. + if (rhsType.template isa()) + std::swap(lhsType, rhsType); + // TODO: Implement and use subtype infra for this. + // If neither type is a subtype of the other, then the result is false. + if (lhsType.template isa() && + rhsType.template isa()) + return IntegerAttr::get(IntegerType::get(op.getContext(), 1), equalIsTrue); + + if (lhsType.template isa() && + !rhsType.template isa()) + return IntegerAttr::get(IntegerType::get(op.getContext(), 1), !equalIsTrue); + + return nullptr; +} + //===----------------------------------------------------------------------===// // Aten__Is__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Is__Op::fold(ArrayRef operands) { - auto lhsType = self().getType(); - auto rhsType = obj().getType(); - // If either type is a NoneType, make it be the lhsType. - if (rhsType.isa()) - std::swap(lhsType, rhsType); - // TODO: Implement and use subtype infra for this. - // If neither type is a subtype of the other, then the result is false. - if (lhsType.isa() && !rhsType.isa()) - return IntegerAttr::get(IntegerType::get(getContext(), 1), 0); - return nullptr; + return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true); +} + +//===----------------------------------------------------------------------===// +// Aten__Isnot__Op +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Isnot__Op::fold(ArrayRef operands) { + return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false); +} + +//===----------------------------------------------------------------------===// +// Aten__Not__Op +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Not__Op::fold(ArrayRef operands) { + bool value; + if (!matchPattern(getOperand(), m_TorchConstantBool(&value))) + return nullptr; + return IntegerAttr::get(IntegerType::get(getContext(), 1), !value); } //===----------------------------------------------------------------------===// @@ -465,14 +495,19 @@ static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) { static_cast(value)); } -OpFoldResult AtenGtIntOp::fold(ArrayRef operands) { - auto lhs = operands[0].dyn_cast_or_null(); - auto rhs = operands[1].dyn_cast_or_null(); - if (lhs && rhs) { - return getI1IntegerAttr(getContext(), lhs.getValue().getSExtValue() > - rhs.getValue().getSExtValue()); - } - return nullptr; +using ConstantIntComparator = std::function; +template +static OpFoldResult comparatorFoldHelper(OpTy op, + ConstantIntComparator comparator) { + if (op.getOperand(0) == op.getOperand(1)) + return getI1IntegerAttr(op.getContext(), comparator(0, 0)); + + int64_t lhs, rhs; + if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) || + !matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs))) + return nullptr; + + return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs)); } //===----------------------------------------------------------------------===// @@ -480,16 +515,53 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult AtenNeIntOp::fold(ArrayRef operands) { - // `torch.aten.ne.int %x, %x` -> `false` - if (getOperand(0) == getOperand(1)) - return getI1IntegerAttr(getContext(), false); - auto lhs = operands[0].dyn_cast_or_null(); - auto rhs = operands[1].dyn_cast_or_null(); - if (lhs && rhs) { - return getI1IntegerAttr(getContext(), lhs.getValue().getSExtValue() != - rhs.getValue().getSExtValue()); - } - return nullptr; + return comparatorFoldHelper(*this, + [](int64_t a, int64_t b) { return a != b; }); +} + +//===----------------------------------------------------------------------===// +// AtenEqIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenEqIntOp::fold(ArrayRef operands) { + return comparatorFoldHelper(*this, + [](int64_t a, int64_t b) { return a == b; }); +} + +//===----------------------------------------------------------------------===// +// AtenLtIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenLtIntOp::fold(ArrayRef operands) { + return comparatorFoldHelper(*this, + [](int64_t a, int64_t b) { return a < b; }); +} + +//===----------------------------------------------------------------------===// +// AtenLeIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenLeIntOp::fold(ArrayRef operands) { + return comparatorFoldHelper(*this, + [](int64_t a, int64_t b) { return a <= b; }); +} + +//===----------------------------------------------------------------------===// +// AtenGtIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGtIntOp::fold(ArrayRef operands) { + return comparatorFoldHelper(*this, + [](int64_t a, int64_t b) { return a > b; }); +} + +//===----------------------------------------------------------------------===// +// AtenGeIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGeIntOp::fold(ArrayRef operands) { + return comparatorFoldHelper(*this, + [](int64_t a, int64_t b) { return a >= b; }); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index 0afc58bf2..7232155be 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -18,6 +18,8 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" using namespace mlir; using namespace mlir::NPCOMP; @@ -73,8 +75,10 @@ public: ObjectGraphInfo(ModuleOp module) : globalSlotBuilder(module.getBodyRegion()), symbolTable(module) {} - LogicalResult initialize(NnModuleOp root) { - return recursivelyTraverse(root); + LogicalResult initialize(NnModuleOp rootNnModule) { + if (failed(collectUsedSlots())) + return failure(); + return recursivelyTraverse(rootNnModule); } LinkageInfo getSlotLinkageInfo(SlotOp op) { @@ -97,6 +101,51 @@ public: } private: + LogicalResult collectUsedSlots() { + // Collect all the slots in each module. + llvm::StringMap> moduleClassNameToSlots; + symbolTable.getOp()->walk([&](NnModuleOp moduleOp) { + llvm::StringMap nameToSlot; + for (auto attrOp : moduleOp.getOps()) + nameToSlot[attrOp.name()] = attrOp; + moduleClassNameToSlots[moduleOp.getClassName()] = nameToSlot; + }); + + // Find all the module slots that are accessed through `PrimGetAttrOp` or + // `PrimSetAttrOp`. + symbolTable.getOp()->walk([&](Operation *op) { + if (!isa(op)) + return; + + Value module; + StringRef slotName; + if (auto getAttrOp = llvm::dyn_cast(op)) { + module = getAttrOp.receiver(); + slotName = getAttrOp.name(); + } else { + auto setAttrOp = cast(op); + module = setAttrOp.receiver(); + slotName = setAttrOp.name(); + } + + auto moduleType = module.getType().cast(); + auto slots = moduleClassNameToSlots.find(moduleType.getClassName()); + // TODO: Improve verifier so that this can never happen + if (slots == moduleClassNameToSlots.end()) + op->emitError() << "Reference to non-existing module type " + << moduleType.getClassName(); + + llvm::StringMap nameToSlot = slots->getValue(); + auto slotIt = nameToSlot.find(slotName); + // TODO: Improve verifier so that this can never happen + if (slotIt == nameToSlot.end()) + op->emitError() << "Reference to non-existing module slot " << slotName + << "in " << moduleType.getClassName(); + usedSlots.insert(slotIt->getValue()); + }); + return success(); + } + LogicalResult recursivelyTraverse(NnModuleOp nnModule) { std::string pathToClassFromRoot = llvm::join(nameStack, "."); if (!seenNnModules.insert({nnModule, pathToClassFromRoot}).second) { @@ -127,7 +176,7 @@ private: assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end()); slotToGlobalSlot[slot] = globalSlot; slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.isPrivate()}; - if (failed(populateGlobalSlotInitializer(globalSlot, slot.value()))) + if (failed(populateGlobalSlotInitializer(globalSlot, slot))) return failure(); } nameStack.pop_back(); @@ -142,11 +191,12 @@ private: return success(); } LogicalResult populateGlobalSlotInitializer(GlobalSlotOp globalSlot, - Value initialValue) { + SlotOp slot) { OpBuilder builder(globalSlot.getContext()); builder.createBlock(&globalSlot.getRegion()); SmallPtrSet needToClone; + Value initialValue = slot.value(); SmallVector worklist = {initialValue.getDefiningOp()}; while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); @@ -167,6 +217,8 @@ private: for (Value result : op->getResults()) { if (!hasMeaningfulObjectIdentity(result.getType())) continue; + if (usedSlots.find(slot) == usedSlots.end()) + continue; if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result) .second) { return op->emitError() << "potentially-aliased value used to " @@ -205,6 +257,9 @@ private: // which cannot be used in multiple initializers because their object // identity is important. DenseSet objectsWithIdentityAlreadyCopiedIntoInitializers; + // Used to keep track of all the used torch slots so that the restrictions can + // be applied to those slots only. + DenseSet usedSlots; }; } // namespace diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir index 301484f84..24a4ae51a 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir @@ -39,3 +39,24 @@ torch.nn_module { torch.slot "f", %f : !torch.float torch.slot "t", %t : !torch.tensor } : !torch.nn.Module<"c"> + + +// ----- + +// CHECK-LABEL: torch.global_slot @t1 : !torch.tensor { +// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor +// CHECK: torch.global_slot.init %[[T]] : !torch.tensor + +// CHECK-LABEL: torch.global_slot @t2 : !torch.tensor { +// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor +// CHECK: torch.global_slot.init %[[T]] : !torch.tensor + +%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor +torch.class_type @c { + torch.attr "t1" : !torch.tensor + torch.attr "t2" : !torch.tensor +} +torch.nn_module { + torch.slot "t1", %t : !torch.tensor + torch.slot "t2", %t : !torch.tensor +} : !torch.nn.Module<"c"> diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir index c6e5a5a5e..b0ed5edfe 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir @@ -42,3 +42,29 @@ torch.nn_module { torch.slot "t1", %t : !torch.tensor torch.slot "t2", %t : !torch.tensor } : !torch.nn.Module<"c"> +builtin.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor { + %t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor + %t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor + %cst = torch.constant.int 1 + %ret = torch.aten.add.Tensor %t1, %t2, %cst : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor + return %ret : !torch.tensor +} + +// ----- + +torch.class_type @c { + torch.attr "t1" : !torch.tensor + torch.attr "t2" : !torch.tensor +} + +// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}} +%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor +torch.nn_module { + torch.slot "t1", %t : !torch.tensor + torch.slot "t2", %t : !torch.tensor +} : !torch.nn.Module<"c"> +builtin.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) { + torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor + torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor + return +} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 88de03ed6..97c12025c 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -8,6 +8,30 @@ func @torch.aten.__is__(%arg0: !torch.list, %arg1: !torch.none) -> ! return %0 : !torch.bool } +// CHECK-LABEL: func @torch.aten.__is__$none_is_none +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.__is__$none_is_none(%arg0: !torch.none, %arg1: !torch.none) -> !torch.bool { + %0 = torch.aten.__is__ %arg0, %arg1 : !torch.none, !torch.none -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.__isnot__ +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.__isnot__(%arg0: !torch.list, %arg1: !torch.none) -> !torch.bool { + %0 = torch.aten.__isnot__ %arg0, %arg1 : !torch.list, !torch.none -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.__isnot__$none_isnot_none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torch.none) -> !torch.bool { + %0 = torch.aten.__isnot__ %arg0, %arg1 : !torch.none, !torch.none -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func @torch.aten.size$canonicalize_to_list( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list { // CHECK: %[[C2:.*]] = torch.constant.int 2 @@ -30,6 +54,122 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.l return %0 : !torch.list } +// CHECK-LABEL: func @torch.aten.ne.int$same_operand( +// CHECK-SAME: %{{.*}}: !torch.int) -> !torch.bool { +// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-NEXT: return %[[FALSE]] : !torch.bool +func @torch.aten.ne.int$same_operand(%arg0: !torch.int) -> !torch.bool { + %0 = torch.aten.ne.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ne.int$same_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.ne.int$same_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int4_0 = torch.constant.int 4 + %2 = torch.aten.ne.int %int4, %int4_0 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ne.int$different_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.ne.int$different_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %2 = torch.aten.ne.int %int4, %int5 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.int$different_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.eq.int$different_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %2 = torch.aten.eq.int %int4, %int5 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.int$same_operand( +// CHECK-SAME: %{{.*}}: !torch.int) -> !torch.bool { +// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true +// CHECK-NEXT: return %[[F]] : !torch.bool +func @torch.aten.eq.int$same_operand(%arg0: !torch.int) -> !torch.bool { + %0 = torch.aten.eq.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.eq.int$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.eq.int$same_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int4_0 = torch.constant.int 4 + %2 = torch.aten.eq.int %int4, %int4_0 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.lt.int$evaluate_to_true() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.lt.int$evaluate_to_true() -> !torch.bool { + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %2 = torch.aten.lt.int %int4, %int5 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.lt.int$same_operand( +// CHECK-SAME: %{{.*}}: !torch.int) -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.lt.int$same_operand(%arg0: !torch.int) -> !torch.bool { + %2 = torch.aten.lt.int %arg0, %arg0: !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.lt.int$same_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.lt.int$same_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int4_0 = torch.constant.int 4 + %2 = torch.aten.lt.int %int4, %int4_0 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.le.int$evaluate_to_true() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.le.int$evaluate_to_true() -> !torch.bool { + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %2 = torch.aten.le.int %int4, %int5 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.le.int$same_operand( +// CHECK-SAME: %{{.*}}: !torch.int) -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.le.int$same_operand(%arg0: !torch.int) -> !torch.bool { + %2 = torch.aten.le.int %arg0, %arg0: !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.le.int$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.le.int$same_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int4_0 = torch.constant.int 4 + %2 = torch.aten.le.int %int4, %int4_0 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + // CHECK-LABEL: func @torch.aten.gt.int$evaluate_to_true() -> !torch.bool { // CHECK-NEXT: %[[T:.*]] = torch.constant.bool true // CHECK-NEXT: return %[[T]] : !torch.bool @@ -50,35 +190,44 @@ func @torch.aten.gt.int$evaluate_to_false() -> !torch.bool { return %0 : !torch.bool } -// CHECK-LABEL: func @torch.aten.ne.int$same_operand( -// CHECK-SAME: %{{.*}}: !torch.int) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false -// CHECK-NEXT: return %[[F]] : !torch.bool -func @torch.aten.ne.int$same_operand(%arg0: !torch.int) -> !torch.bool { - %0 = torch.aten.ne.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.bool - return %0 : !torch.bool -} - -// CHECK-LABEL: func @torch.aten.ne.int$same_value() -> !torch.bool { -// CHECK: %[[VAL_0:.*]] = torch.constant.bool false -// CHECK: return %[[VAL_0]] : !torch.bool -func @torch.aten.ne.int$same_value() -> !torch.bool { - %int4 = torch.constant.int 4 - %int4_0 = torch.constant.int 4 - %2 = torch.aten.ne.int %int4, %int4_0 : !torch.int, !torch.int -> !torch.bool - return %2 : !torch.bool -} - -// CHECK-LABEL: func @torch.aten.ne.int$different_value() -> !torch.bool { -// CHECK: %[[VAL_0:.*]] = torch.constant.bool true -// CHECK: return %[[VAL_0]] : !torch.bool -func @torch.aten.ne.int$different_value() -> !torch.bool { +// CHECK-LABEL: func @torch.aten.ge.int$evaluate_to_false() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.ge.int$evaluate_to_false() -> !torch.bool { %int4 = torch.constant.int 4 %int5 = torch.constant.int 5 - %2 = torch.aten.ne.int %int4, %int5 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.ge.int %int4, %int5 : !torch.int, !torch.int -> !torch.bool return %2 : !torch.bool } +// CHECK-LABEL: func @torch.aten.ge.int$same_operand( +// CHECK-SAME: %{{.*}}: !torch.int) -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.ge.int$same_operand(%arg0: !torch.int) -> !torch.bool { + %2 = torch.aten.ge.int %arg0, %arg0: !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ge.int$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.ge.int$same_value() -> !torch.bool { + %int4 = torch.constant.int 4 + %int4_0 = torch.constant.int 4 + %2 = torch.aten.ge.int %int4, %int4_0 : !torch.int, !torch.int -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.__not__ +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.__not__() -> !torch.bool { + %false = torch.constant.bool false + %ret = torch.aten.__not__ %false : !torch.bool -> !torch.bool + return %ret: !torch.bool +} + // CHECK-LABEL: func @torch.aten.len.t$of_size( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int { // CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> !torch.int diff --git a/tools/run_lit.sh b/tools/run_lit.sh index 04b0cf1af..67aa4de76 100755 --- a/tools/run_lit.sh +++ b/tools/run_lit.sh @@ -6,8 +6,7 @@ set -e td="$(realpath $(dirname $0)/..)" build_dir="$td/build" -install_mlir="$td/install-mlir" -build_mlir="$td/external/llvm-project/build" +build_mlir="$build_dir/llvm" lit_exe="$build_mlir/bin/llvm-lit" if ! [ -f "$lit_exe" ]; then