Add linspace/cumprod/roll ops (#2498)

Add linspace/cumprod/roll ops to ODS and add shape inference functions
to make it work with LTC.

Also, add some tensor utils to LTC library for searching for non-detach
copy nodes.
pull/2502/head
Jae Hoon (Antonio) Kim 2023-10-03 11:01:07 -04:00 committed by GitHub
parent d10a86f51c
commit 32d9b20bde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 129 additions and 20 deletions

View File

@ -830,7 +830,6 @@ STABLEHLO_PASS_SET = {
"ReshapeAliasCollapseModule_basic", "ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic", "ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic", "ReshapeExpandModule_basic",
"RollModule_basic",
"TestMultipleTensorReturn_basic", "TestMultipleTensorReturn_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
@ -1355,7 +1354,6 @@ LTC_XFAIL_SET = {
"NeFloatIntModule_basic", "NeFloatIntModule_basic",
"NeIntModule_basic", "NeIntModule_basic",
"QuantizedMLP_basic", "QuantizedMLP_basic",
"RollModule_basic",
"ScalarImplicitFloatModule_basic", "ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic", "ScalarImplicitIntModule_basic",
"SliceEndSleStartModule_basic", "SliceEndSleStartModule_basic",

View File

@ -6440,6 +6440,31 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [
}]; }];
} }
def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::cumprod : (Tensor, int, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchOptionalIntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCumprodOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenCumprodOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [ def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -10464,6 +10489,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
}]; }];
} }
def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchScalarType:$start,
AnyTorchScalarType:$end,
Torch_IntType:$steps,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenLinspaceOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -83,7 +83,7 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
TorchMlirNode* TorchMlirNode::mlir_node(int index) { TorchMlirNode* TorchMlirNode::mlir_node(int index) const {
return dynamic_cast<TorchMlirNode*>(operands_.at(index).get()); return dynamic_cast<TorchMlirNode*>(operands_.at(index).get());
} }

View File

@ -51,7 +51,7 @@ public:
hash_t shapeHash() const override; hash_t shapeHash() const override;
TorchMlirNode* mlir_node(int index); TorchMlirNode* mlir_node(int index) const;
virtual TorchMlirOpVector virtual TorchMlirOpVector
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;

View File

@ -386,5 +386,17 @@ std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})}; return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
} }
std::vector<torch::lazy::Shape> compute_shape_roll(
const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
auto out_meta =
at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
} // namespace lazy } // namespace lazy
} // namespace torch } // namespace torch

View File

@ -7,27 +7,65 @@
namespace torch { namespace torch {
namespace lazy { namespace lazy {
bool is_detach_copy(const torch::lazy::Node* node) {
return node && node->op() == torch::lazy::DetachCopy::ClassOpKind();
}
bool is_detach_copy(const torch::lazy::Value& value) { bool is_detach_copy(const torch::lazy::Value& value) {
return value->op() == torch::lazy::DetachCopy::ClassOpKind(); return is_detach_copy(value.node.get());
} }
torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) {
if (!node) { return nullptr; }
torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<torch::lazy::TorchMlirNode*>(node);
while(mlir_node && is_detach_copy(mlir_node)) {
mlir_node = mlir_node->mlir_node(0);
}
if (!mlir_node) {
return node;
}
return mlir_node;
}
const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) {
if (!node) { return nullptr; }
const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<const torch::lazy::TorchMlirNode*>(node);
while(mlir_node && is_detach_copy(mlir_node)) {
mlir_node = mlir_node->mlir_node(0);
}
if (!mlir_node) {
return node;
}
return mlir_node;
}
torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) {
if (!node) {
return nullptr;
}
node = extract_non_detach_copy_node(node);
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
return dynamic_cast<torch::lazy::DeviceData*>(node);
}
return nullptr;
}
const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) {
if (!node) {
return nullptr;
}
node = extract_non_detach_copy_node(node);
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
return dynamic_cast<const torch::lazy::DeviceData*>(node);
}
return nullptr;
}
torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
if (!value) { if (!value) {
return nullptr; return nullptr;
} }
torch::lazy::TorchMlirNode* node = dynamic_cast<torch::lazy::TorchMlirNode*>(value.node.get()); return device_data_cast(value.node.get());
while(node) {
if (node->op() == torch::lazy::DeviceData::ClassOpKind()) {
return dynamic_cast<torch::lazy::DeviceData*>(node);
}
else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) {
node = node->mlir_node(0);
}
else {
break;
}
}
return nullptr;
} }
torch::lazy::DeviceData* device_data_cast( torch::lazy::DeviceData* device_data_cast(

View File

@ -8,10 +8,15 @@
namespace torch { namespace torch {
namespace lazy { namespace lazy {
TORCH_API bool is_detach_copy(const torch::lazy::Value& value); TORCH_API bool is_detach_copy(const torch::lazy::Node*);
TORCH_API bool is_detach_copy(const torch::lazy::Value&);
TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*);
TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*);
TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*);
TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*);
TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
TORCH_API torch::lazy::DeviceData* device_data_cast( TORCH_API torch::lazy::DeviceData* device_data_cast(
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
); );

View File

@ -471,6 +471,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
@ -625,6 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
# Functionalization ops # Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)") emit("aten::alias_copy : (Tensor) -> (Tensor)")