Add where, gt, bucketize and reshape ops to Torch dialect

This patch adds the where, gt, bucketize, reshape, and empty_like
ops to the Torch dialect. These ops are used by the histogram
calibration module.

TEST: Successfully lowers these ops to the Torch dialect in the histogram calibration module.
harsh 2021-12-09 22:53:14 +00:00, committed by Sean Silva
commit 03b6edce68 (parent cfc8de36f8)
2 changed files with 101 additions and 17 deletions
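For context, here is a minimal PyTorch sketch of the kind of observer code that exercises these ops. The function and its inputs are hypothetical, not taken from the actual calibration module:

import torch

# Hypothetical stand-in for the histogram calibration logic.
def histogram_bins(values: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor:
    flat = values.reshape([-1])                      # aten::reshape
    mask = torch.gt(flat, bin_edges[0])              # aten::gt.Tensor
    clamped = torch.where(mask, flat, bin_edges[0])  # aten::where.self
    # Map each clamped value to its histogram bin index.
    return torch.bucketize(clamped, bin_edges, out_int32=False, right=False)

bins = histogram_bins(torch.randn(4, 8), torch.linspace(-3.0, 3.0, steps=16))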

File 1 of 2: generated Torch dialect ODS definitions (TableGen)

@@ -512,6 +512,36 @@ def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenGtTensorOp : Torch_Op<"aten.gt.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::gt_.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics
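For reference, the eager-mode calls that correspond to the two definitions above; the tensor values are just illustrative:

import torch

a = torch.tensor([1.0, 5.0, 3.0])
b = torch.tensor([2.0, 2.0, 2.0])
mask = torch.gt(a, b)  # out-of-place form: aten.gt.Tensor, returns a bool tensor
a.gt_(b)               # trailing-underscore in-place variant: aten.gt_.Tensor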
@@ -1071,22 +1101,6 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
    AllowsTypeRefinement,
    HasValueSemantics
@@ -1942,6 +1956,23 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
  let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
}
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$boundaries,
    Torch_BoolType:$out_int32,
    Torch_BoolType:$right
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` type($self) `,` type($boundaries) `,` type($out_int32) `,` type($right) `->` type($result)";
}
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
    AllowsTypeRefinement
  ]> {
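A quick eager-mode illustration of the out_int32 and right operands of aten.bucketize.Tensor (example values only):

import torch

boundaries = torch.tensor([1.0, 3.0, 5.0])
v = torch.tensor([0.5, 1.0, 4.2, 9.0])
# right=False picks index i with boundaries[i-1] < v <= boundaries[i].
torch.bucketize(v, boundaries, out_int32=False, right=False)  # [0, 0, 2, 3], int64
# right=True picks index i with boundaries[i-1] <= v < boundaries[i].
torch.bucketize(v, boundaries, out_int32=True, right=True)    # [0, 1, 2, 3], int32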
@@ -2002,6 +2033,25 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
  let assemblyFormat = "$weight `,` $indices `,` $padding_idx `,` $scale_grad_by_freq `,` $sparse attr-dict `:` type($weight) `,` type($indices) `,` type($padding_idx) `,` type($scale_grad_by_freq) `,` type($sparse) `->` type($result)";
}
def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchOptionalIntType:$dtype,
    TorchOptionalIntType:$layout,
    TorchOptionalDeviceType:$device,
    TorchOptionalBoolType:$pin_memory,
    TorchOptionalIntType:$memory_format
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
}
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
    AllowsTypeRefinement,
    HasValueSemantics
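The optional dtype, layout, device, pin_memory, and memory_format operands mirror the keyword arguments of torch.empty_like; a small sketch:

import torch

x = torch.randn(2, 3)
y = torch.empty_like(x)                     # uninitialized, same shape and dtype as x
z = torch.empty_like(x, dtype=torch.int32)  # overrides the optional dtype operand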
@@ -2139,6 +2189,20 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
  let assemblyFormat = "$self `,` $repeats attr-dict `:` type($self) `,` type($repeats) `->` type($result)";
}
def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchIntListType:$shape
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $shape attr-dict `:` type($self) `,` type($shape) `->` type($result)";
}
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
    AllowsTypeRefinement
  ]> {
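Note that aten.reshape carries AllowsTypeRefinement but not HasValueSemantics: torch.reshape may return a view that aliases its input, as this sketch shows:

import torch

x = torch.arange(6)
y = torch.reshape(x, [2, 3])  # may alias x's storage instead of copying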
@@ -2312,6 +2376,22 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
  let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
}
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
    AllowsTypeRefinement
  ]> {
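A minimal eager-mode example of the ternary select that aten.where.self models:

import torch

cond = torch.tensor([True, False, True])
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.zeros(3)
torch.where(cond, x, y)  # tensor([1., 0., 3.])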

File 2 of 2: ODS emitter, emit_aten_ops (Python)

@@ -454,6 +454,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
        "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
        "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
        "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
@@ -479,7 +480,6 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
    emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::gelu : (Tensor) -> (Tensor)")
@@ -550,10 +550,12 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
    emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
    emit("aten::contiguous : (Tensor, int) -> (Tensor)")
    emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
    emit("aten::detach : (Tensor) -> (Tensor)")
    emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
    emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
    emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
@@ -563,6 +565,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
    emit("aten::numel : (Tensor) -> (int)")
    emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
    emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
    emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
    emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
    emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
@@ -574,6 +577,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
    emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
    emit("aten::view : (Tensor, int[]) -> (Tensor)")
    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
    emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
    emit("aten::len.Tensor : (Tensor) -> (int)")
    emit("aten::cpu : (Tensor) -> (Tensor)")