mirror of https://github.com/llvm/torch-mlir
Add where, gt, bucketize and reshape ops to Torch dialect
This patch adds the where, gt, bucketize and reshape ops to the Torch dialect. These ops are present in the histogram calibration module. TEST: Successfully lowers ops to Torch dialect in histogram module.pull/473/head snapshot-20211210.136
parent
cfc8de36f8
commit
03b6edce68
|
@ -512,6 +512,36 @@ def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [
|
|||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenGtTensorOp : Torch_Op<"aten.gt.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::gt_.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -1071,22 +1101,6 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
|
|||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$condition,
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -1942,6 +1956,23 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
|
|||
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$boundaries,
|
||||
Torch_BoolType:$out_int32,
|
||||
Torch_BoolType:$right
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` type($self) `,` type($boundaries) `,` type($out_int32) `,` type($right) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
@ -2002,6 +2033,25 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
|
|||
let assemblyFormat = "$weight `,` $indices `,` $padding_idx `,` $scale_grad_by_freq `,` $sparse attr-dict `:` type($weight) `,` type($indices) `,` type($padding_idx) `,` type($scale_grad_by_freq) `,` type($sparse) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
TorchOptionalBoolType:$pin_memory,
|
||||
TorchOptionalIntType:$memory_format
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -2139,6 +2189,20 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
|
|||
let assemblyFormat = "$self `,` $repeats attr-dict `:` type($self) `,` type($repeats) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$shape
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $shape attr-dict `:` type($self) `,` type($shape) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
@ -2312,6 +2376,22 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
|||
let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$condition,
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
|
|
@ -454,6 +454,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
||||
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
|
@ -479,7 +480,6 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::gelu : (Tensor) -> (Tensor)")
|
||||
|
@ -550,10 +550,12 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
|
||||
emit("aten::detach : (Tensor) -> (Tensor)")
|
||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
||||
|
@ -563,6 +565,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::numel : (Tensor) -> (int)")
|
||||
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
|
||||
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||
|
@ -574,6 +577,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::view : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||
emit("aten::cpu : (Tensor) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue