mirror of https://github.com/llvm/torch-mlir
Add linspace/cumprod/roll ops (#2498)
Add linspace/cumprod/roll ops to ODS and add shape inference functions to make it work with LTC. Also, add some tensor utils to LTC library for searching for non-detach copy nodes.pull/2502/head
parent
d10a86f51c
commit
32d9b20bde
|
@ -830,7 +830,6 @@ STABLEHLO_PASS_SET = {
|
|||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
"ReshapeExpandModule_basic",
|
||||
"RollModule_basic",
|
||||
"TestMultipleTensorReturn_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
|
@ -1355,7 +1354,6 @@ LTC_XFAIL_SET = {
|
|||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"QuantizedMLP_basic",
|
||||
"RollModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScalarImplicitIntModule_basic",
|
||||
"SliceEndSleStartModule_basic",
|
||||
|
|
|
@ -6440,6 +6440,31 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::cumprod : (Tensor, int, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchOptionalIntType:$dtype
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenCumprodOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenCumprodOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -10464,6 +10489,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$start,
|
||||
AnyTorchScalarType:$end,
|
||||
Torch_IntType:$steps,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||
}
|
||||
void AtenLinspaceOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 7, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -83,7 +83,7 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
|
|||
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
|
||||
|
||||
|
||||
TorchMlirNode* TorchMlirNode::mlir_node(int index) {
|
||||
TorchMlirNode* TorchMlirNode::mlir_node(int index) const {
|
||||
return dynamic_cast<TorchMlirNode*>(operands_.at(index).get());
|
||||
}
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ public:
|
|||
|
||||
hash_t shapeHash() const override;
|
||||
|
||||
TorchMlirNode* mlir_node(int index);
|
||||
TorchMlirNode* mlir_node(int index) const;
|
||||
|
||||
virtual TorchMlirOpVector
|
||||
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
|
||||
|
|
|
@ -386,5 +386,17 @@ std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
|
|||
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_roll(
|
||||
const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) {
|
||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||
auto out_meta =
|
||||
at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -7,27 +7,65 @@
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
bool is_detach_copy(const torch::lazy::Node* node) {
|
||||
return node && node->op() == torch::lazy::DetachCopy::ClassOpKind();
|
||||
}
|
||||
bool is_detach_copy(const torch::lazy::Value& value) {
|
||||
return value->op() == torch::lazy::DetachCopy::ClassOpKind();
|
||||
return is_detach_copy(value.node.get());
|
||||
}
|
||||
|
||||
torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) {
|
||||
if (!node) { return nullptr; }
|
||||
|
||||
torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<torch::lazy::TorchMlirNode*>(node);
|
||||
while(mlir_node && is_detach_copy(mlir_node)) {
|
||||
mlir_node = mlir_node->mlir_node(0);
|
||||
}
|
||||
if (!mlir_node) {
|
||||
return node;
|
||||
}
|
||||
return mlir_node;
|
||||
}
|
||||
|
||||
const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) {
|
||||
if (!node) { return nullptr; }
|
||||
|
||||
const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<const torch::lazy::TorchMlirNode*>(node);
|
||||
while(mlir_node && is_detach_copy(mlir_node)) {
|
||||
mlir_node = mlir_node->mlir_node(0);
|
||||
}
|
||||
if (!mlir_node) {
|
||||
return node;
|
||||
}
|
||||
return mlir_node;
|
||||
}
|
||||
|
||||
|
||||
torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) {
|
||||
if (!node) {
|
||||
return nullptr;
|
||||
}
|
||||
node = extract_non_detach_copy_node(node);
|
||||
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||
return dynamic_cast<torch::lazy::DeviceData*>(node);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) {
|
||||
if (!node) {
|
||||
return nullptr;
|
||||
}
|
||||
node = extract_non_detach_copy_node(node);
|
||||
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||
return dynamic_cast<const torch::lazy::DeviceData*>(node);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
|
||||
if (!value) {
|
||||
return nullptr;
|
||||
}
|
||||
torch::lazy::TorchMlirNode* node = dynamic_cast<torch::lazy::TorchMlirNode*>(value.node.get());
|
||||
while(node) {
|
||||
if (node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||
return dynamic_cast<torch::lazy::DeviceData*>(node);
|
||||
}
|
||||
else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) {
|
||||
node = node->mlir_node(0);
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
return device_data_cast(value.node.get());
|
||||
}
|
||||
|
||||
torch::lazy::DeviceData* device_data_cast(
|
||||
|
|
|
@ -8,10 +8,15 @@
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
TORCH_API bool is_detach_copy(const torch::lazy::Value& value);
|
||||
TORCH_API bool is_detach_copy(const torch::lazy::Node*);
|
||||
TORCH_API bool is_detach_copy(const torch::lazy::Value&);
|
||||
|
||||
TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*);
|
||||
TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*);
|
||||
|
||||
TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*);
|
||||
TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*);
|
||||
TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
|
||||
|
||||
TORCH_API torch::lazy::DeviceData* device_data_cast(
|
||||
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
|
||||
);
|
||||
|
|
|
@ -471,6 +471,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
|
||||
emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
|
||||
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
|
||||
|
@ -625,6 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
|
||||
# Functionalization ops
|
||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue