diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 15cf9bb79..e46459f1a 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -830,7 +830,6 @@ STABLEHLO_PASS_SET = {
     "ReshapeAliasCollapseModule_basic",
     "ReshapeAliasExpandModule_basic",
     "ReshapeExpandModule_basic",
-    "RollModule_basic",
     "TestMultipleTensorReturn_basic",
     "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
@@ -1355,7 +1354,6 @@ LTC_XFAIL_SET = {
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "QuantizedMLP_basic",
-    "RollModule_basic",
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
     "SliceEndSleStartModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4ecad92c6..f1338142d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6440,6 +6440,31 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [
   }];
 }
 
+def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::cumprod : (Tensor, int, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchOptionalIntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCumprodOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenCumprodOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -10464,6 +10489,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
   }];
 }
 
+def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchScalarType:$start,
+    AnyTorchScalarType:$end,
+    Torch_IntType:$steps,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenLinspaceOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
index e4b75e5d5..39dc1ad0c 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
@@ -83,7 +83,7 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
 
 hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
 
-TorchMlirNode* TorchMlirNode::mlir_node(int index) {
+TorchMlirNode* TorchMlirNode::mlir_node(int index) const {
   return dynamic_cast<TorchMlirNode*>(operands_.at(index).get());
 }
 
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
index dbf3117db..4b5e196be 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
@@ -51,7 +51,7 @@ public:
 
   hash_t shapeHash() const override;
 
-  TorchMlirNode* mlir_node(int index);
+  TorchMlirNode* mlir_node(int index) const;
 
   virtual TorchMlirOpVector Lower(TorchMlirFunction function,
                                   TorchMlirLoweringContext* loctx) const;
diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
index 043094c67..d5458f9c4 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
@@ -386,5 +386,17 @@ std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
   return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_roll(
+    const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
+std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+  auto out_meta =
+      at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory);
+  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
+}
+
+
 } // namespace lazy
 } // namespace torch
\ No newline at end of file
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp
index 7131e9a66..cdd971680 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp
@@ -7,27 +7,65 @@
 namespace torch {
 namespace lazy {
 
+bool is_detach_copy(const torch::lazy::Node* node) {
+    return node && node->op() == torch::lazy::DetachCopy::ClassOpKind();
+}
 bool is_detach_copy(const torch::lazy::Value& value) {
-    return value->op() == torch::lazy::DetachCopy::ClassOpKind();
+    return is_detach_copy(value.node.get());
 }
 
+torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) {
+    if (!node) { return nullptr; }
+
+    torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<torch::lazy::TorchMlirNode*>(node);
+    while(mlir_node && is_detach_copy(mlir_node)) {
+        mlir_node = mlir_node->mlir_node(0);
+    }
+    if (!mlir_node) {
+        return node;
+    }
+    return mlir_node;
+}
+
+const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) {
+    if (!node) { return nullptr; }
+
+    const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<const torch::lazy::TorchMlirNode*>(node);
+    while(mlir_node && is_detach_copy(mlir_node)) {
+        mlir_node = mlir_node->mlir_node(0);
+    }
+    if (!mlir_node) {
+        return node;
+    }
+    return mlir_node;
+}
+
+
+torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) {
+    if (!node) {
+        return nullptr;
+    }
+    node = extract_non_detach_copy_node(node);
+    if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
+        return dynamic_cast<torch::lazy::DeviceData*>(node);
+    }
+    return nullptr;
+}
+const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) {
+    if (!node) {
+        return nullptr;
+    }
+    node = extract_non_detach_copy_node(node);
+    if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
+        return dynamic_cast<const torch::lazy::DeviceData*>(node);
+    }
+    return nullptr;
+}
 torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
     if (!value) {
         return nullptr;
     }
-    torch::lazy::TorchMlirNode* node = dynamic_cast<torch::lazy::TorchMlirNode*>(value.node.get());
-    while(node) {
-        if (node->op() == torch::lazy::DeviceData::ClassOpKind()) {
-            return dynamic_cast<torch::lazy::DeviceData*>(node);
-        }
-        else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) {
-            node = node->mlir_node(0);
-        }
-        else {
-            break;
-        }
-    }
-    return nullptr;
+    return device_data_cast(value.node.get());
 }
 
 torch::lazy::DeviceData* device_data_cast(
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h
index 717173e9a..745be78c3 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h
@@ -8,10 +8,15 @@
 namespace torch {
 namespace lazy {
 
-TORCH_API bool is_detach_copy(const torch::lazy::Value& value);
+TORCH_API bool is_detach_copy(const torch::lazy::Node*);
+TORCH_API bool is_detach_copy(const torch::lazy::Value&);
+TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*);
+TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*);
+
+TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*);
+TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*);
 TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
-
 TORCH_API torch::lazy::DeviceData* device_data_cast(
     const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
 );
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 56d18d384..f540a1ad2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -471,6 +471,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
+    emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
     emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
@@ -625,6 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
     emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
+    emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
 
     # Functionalization ops
     emit("aten::alias_copy : (Tensor) -> (Tensor)")