mirror of https://github.com/llvm/torch-mlir
Add linspace/cumprod/roll ops (#2498)
Add linspace/cumprod/roll ops to ODS and add shape inference functions to make it work with LTC. Also, add some tensor utils to LTC library for searching for non-detach copy nodes.pull/2502/head
parent
d10a86f51c
commit
32d9b20bde
|
@ -830,7 +830,6 @@ STABLEHLO_PASS_SET = {
|
||||||
"ReshapeAliasCollapseModule_basic",
|
"ReshapeAliasCollapseModule_basic",
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
"ReshapeExpandModule_basic",
|
"ReshapeExpandModule_basic",
|
||||||
"RollModule_basic",
|
|
||||||
"TestMultipleTensorReturn_basic",
|
"TestMultipleTensorReturn_basic",
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
|
@ -1355,7 +1354,6 @@ LTC_XFAIL_SET = {
|
||||||
"NeFloatIntModule_basic",
|
"NeFloatIntModule_basic",
|
||||||
"NeIntModule_basic",
|
"NeIntModule_basic",
|
||||||
"QuantizedMLP_basic",
|
"QuantizedMLP_basic",
|
||||||
"RollModule_basic",
|
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
"SliceEndSleStartModule_basic",
|
"SliceEndSleStartModule_basic",
|
||||||
|
|
|
@ -6440,6 +6440,31 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::cumprod : (Tensor, int, int?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchOptionalIntType:$dtype
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCumprodOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenCumprodOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
|
def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -10464,6 +10489,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchScalarType:$start,
|
||||||
|
AnyTorchScalarType:$end,
|
||||||
|
Torch_IntType:$steps,
|
||||||
|
AnyTorchOptionalIntType:$dtype,
|
||||||
|
AnyTorchOptionalIntType:$layout,
|
||||||
|
AnyTorchOptionalDeviceType:$device,
|
||||||
|
AnyTorchOptionalBoolType:$pin_memory
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||||
|
}
|
||||||
|
void AtenLinspaceOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 7, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -83,7 +83,7 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
|
||||||
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
|
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
|
||||||
|
|
||||||
|
|
||||||
TorchMlirNode* TorchMlirNode::mlir_node(int index) {
|
TorchMlirNode* TorchMlirNode::mlir_node(int index) const {
|
||||||
return dynamic_cast<TorchMlirNode*>(operands_.at(index).get());
|
return dynamic_cast<TorchMlirNode*>(operands_.at(index).get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ public:
|
||||||
|
|
||||||
hash_t shapeHash() const override;
|
hash_t shapeHash() const override;
|
||||||
|
|
||||||
TorchMlirNode* mlir_node(int index);
|
TorchMlirNode* mlir_node(int index) const;
|
||||||
|
|
||||||
virtual TorchMlirOpVector
|
virtual TorchMlirOpVector
|
||||||
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
|
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
|
||||||
|
|
|
@ -386,5 +386,17 @@ std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
|
||||||
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
|
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> compute_shape_roll(
|
||||||
|
const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) {
|
||||||
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||||
|
auto out_meta =
|
||||||
|
at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
||||||
|
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
|
@ -7,27 +7,65 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
|
bool is_detach_copy(const torch::lazy::Node* node) {
|
||||||
|
return node && node->op() == torch::lazy::DetachCopy::ClassOpKind();
|
||||||
|
}
|
||||||
bool is_detach_copy(const torch::lazy::Value& value) {
|
bool is_detach_copy(const torch::lazy::Value& value) {
|
||||||
return value->op() == torch::lazy::DetachCopy::ClassOpKind();
|
return is_detach_copy(value.node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) {
|
||||||
|
if (!node) { return nullptr; }
|
||||||
|
|
||||||
|
torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<torch::lazy::TorchMlirNode*>(node);
|
||||||
|
while(mlir_node && is_detach_copy(mlir_node)) {
|
||||||
|
mlir_node = mlir_node->mlir_node(0);
|
||||||
|
}
|
||||||
|
if (!mlir_node) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
return mlir_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) {
|
||||||
|
if (!node) { return nullptr; }
|
||||||
|
|
||||||
|
const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<const torch::lazy::TorchMlirNode*>(node);
|
||||||
|
while(mlir_node && is_detach_copy(mlir_node)) {
|
||||||
|
mlir_node = mlir_node->mlir_node(0);
|
||||||
|
}
|
||||||
|
if (!mlir_node) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
return mlir_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) {
|
||||||
|
if (!node) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
node = extract_non_detach_copy_node(node);
|
||||||
|
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||||
|
return dynamic_cast<torch::lazy::DeviceData*>(node);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) {
|
||||||
|
if (!node) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
node = extract_non_detach_copy_node(node);
|
||||||
|
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||||
|
return dynamic_cast<const torch::lazy::DeviceData*>(node);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
|
torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
torch::lazy::TorchMlirNode* node = dynamic_cast<torch::lazy::TorchMlirNode*>(value.node.get());
|
return device_data_cast(value.node.get());
|
||||||
while(node) {
|
|
||||||
if (node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
|
||||||
return dynamic_cast<torch::lazy::DeviceData*>(node);
|
|
||||||
}
|
|
||||||
else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) {
|
|
||||||
node = node->mlir_node(0);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::lazy::DeviceData* device_data_cast(
|
torch::lazy::DeviceData* device_data_cast(
|
||||||
|
|
|
@ -8,10 +8,15 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
TORCH_API bool is_detach_copy(const torch::lazy::Value& value);
|
TORCH_API bool is_detach_copy(const torch::lazy::Node*);
|
||||||
|
TORCH_API bool is_detach_copy(const torch::lazy::Value&);
|
||||||
|
|
||||||
|
TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*);
|
||||||
|
TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*);
|
||||||
|
|
||||||
|
TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*);
|
||||||
|
TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*);
|
||||||
TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
|
TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
|
||||||
|
|
||||||
TORCH_API torch::lazy::DeviceData* device_data_cast(
|
TORCH_API torch::lazy::DeviceData* device_data_cast(
|
||||||
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
|
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
|
||||||
);
|
);
|
||||||
|
|
|
@ -471,6 +471,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
|
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
|
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
|
||||||
|
emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
|
||||||
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
|
||||||
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
|
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
|
||||||
|
@ -625,6 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
||||||
|
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
|
|
||||||
# Functionalization ops
|
# Functionalization ops
|
||||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue