Add verification for torch permute op (#2551)

- adds support for an optional verifier to the generated torch op
tablegen (GeneratedTorchOps.td)
- uses the above to add a verifier for the torch permute op. 

Motivation: I hit an unclear error from linalg while developing a
decomposition pass for pixel_shuffle. The error would have been clearer
if the problem had been detected earlier in the invalid aten.permute op.

Testing: new tests added. To run added tests, from the base directory
run

```
 ./build/bin/llvm-lit  test/Dialect/Torch/invalid.mlir
 ```
pull/2563/head snapshot-20231116.1024
James Newling 2023-11-15 11:47:54 -08:00 committed by GitHub
parent e81282ae8f
commit dad1f012f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 214 additions and 27 deletions

View File

@ -6422,29 +6422,6 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
}];
}
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPermuteOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
AllowsTypeRefinement,
HasValueSemantics,
@ -6469,6 +6446,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
}];
}
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPermuteOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}
def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [
AllowsTypeRefinement,
ReadOnly

View File

@ -2859,6 +2859,96 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success();
}
LogicalResult AtenPermuteOp::verify() {
// Verification of the permute op for input & output dimensions with
// statically known sizes.
SmallVector<Value> permutation;
auto permutationObtained = getListConstructElements(getDims(), permutation);
if (!permutationObtained) {
return success();
}
auto outType = getResult().getType().cast<BaseTensorType>();
auto inType = getSelf().getType().cast<BaseTensorType>();
if (!outType.hasSizes() || !inType.hasSizes()) {
return success();
}
auto outShape = outType.getSizes();
auto inShape = inType.getSizes();
auto outRank = outShape.size();
if (outRank != inShape.size()) {
return emitOpError(
"expected input and output tensors to have same rank, but ")
<< inShape.size() << " != " << outRank << '.';
}
if (outRank != permutation.size()) {
return emitOpError() << "expected permutation to have size equal result "
"tensor rank. The permutation has "
<< permutation.size()
<< " elements, the output has rank " << outRank << '.';
}
// Initialization of the reverse permutation. -1 denotes an unknown
// permutation index.
SmallVector<int64_t> reversePermutation(outRank, -1);
// In this loop:
// (1) check that the permutation indices are in bounds, and not duplicated.
// (2) populate reversePermutation (to check for duplicates).
// (3) check that the input and output shapes agree with the permutation. For
// example, if the permutation is (1,2,0) and the input shape is (2,3,5),
// then the output shape must be (3,5,2).
for (uint64_t to = 0; to < outRank; ++to) {
int64_t from;
auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from));
if (!fromIsSet) {
continue;
}
// if 'from' is the unkwown index, continue.
if (from == -1) {
continue;
}
if (!isValidDim(from, outRank)) {
return emitError("observed invalid index in permutation (")
<< from << ") for input tensor of rank " << outRank << '.';
}
if (reversePermutation[from] != -1) {
return emitOpError("has a duplicate dimension (")
<< from << ") in its permutation " << getDims() << '.';
}
reversePermutation[from] = to;
auto dimSizesDefined =
inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
auto dimSizesDifferent = inShape[from] != outShape[to];
if (dimSizesDefined && dimSizesDifferent) {
return emitOpError("has a permutation which is not compatible with the "
"input and output shapes. ")
<< "The input shape in dimension " << from << " is "
<< inShape[from] << ", and the output shape in dimension " << to
<< " is " << outShape[to]
<< " : they should be the same with this permutation. ";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// DtypeCalculateYieldDtypesOp
//===----------------------------------------------------------------------===//

View File

@ -114,7 +114,7 @@ ODS_BANNER = f"""//===-------------------------------------------------------*-
def raw_emit_op(operator: JitOperator,
emitter_td: TextEmitter,
*, traits: List[str],
has_folder: bool, has_canonicalizer: bool):
has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
"""Emit the ODS for a JitOperator to a textual file.
This is the lowest level of emission and is responsible for low-level
@ -199,6 +199,8 @@ def raw_emit_op(operator: JitOperator,
p_td("let hasFolder = 1;")
if has_canonicalizer:
p_td("let hasCanonicalizer = 1;")
if has_verifier:
p_td("let hasVerifier = 1;")
p_td("}")
p_td("\n")
@ -208,7 +210,8 @@ def emit_op(operator: JitOperator,
*,
traits: Optional[List[str]] = None,
has_folder: bool = False,
has_canonicalizer: bool = False):
has_canonicalizer: bool = False,
has_verifier: bool = False):
"""Main entry point for op emission.
Besides emitting the op, it deduces / adds traits based on the operator
@ -228,7 +231,8 @@ def emit_op(operator: JitOperator,
emitter_td,
traits=traits,
has_folder=has_folder,
has_canonicalizer=has_canonicalizer)
has_canonicalizer=has_canonicalizer,
has_verifier=has_verifier)
def emit_ops(emitter_td: TextEmitter, registry: Registry):
@ -481,8 +485,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)")
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")

View File

@ -281,3 +281,84 @@ func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
return %0 : !torch.vtensor<*,f64>
}
// -----
func.func @torch.permute$test_changing_rank (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{expected input and output tensors to have same rank, but 3 != 4}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
return %3 : !torch.vtensor<[1,2,3,4],f32>
}
// -----
func.func @torch.permute$test_permutation_too_short (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%perm = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{The permutation has 2 elements, the output has rank 3}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
return %3 : !torch.vtensor<[1,2,3],f32>
}
// -----
func.func @torch.permute$duplicate_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[2,3,1],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%perm = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{'torch.aten.permute' op has a duplicate dimension (1) in its permutation}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1],f32>
return %3 : !torch.vtensor<[2,3,1],f32>
}
// -----
func.func @torch.permute$incorrect_output_shape (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[3,1,2],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%none = torch.constant.none
%perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{'torch.aten.permute' op has a permutation which is not compatible with the input and output shapes. The input shape in dimension 1 is 2, and the output shape in dimension 0 is 3 : they should be the same with this permutation.}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[3,1,2],f32>
return %3 : !torch.vtensor<[3,1,2],f32>
}
// -----
func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int7 = torch.constant.int 7
%perm = torch.prim.ListConstruct %int0, %int1, %int7 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{observed invalid index in permutation (7) for input tensor of rank 3.}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
return %3 : !torch.vtensor<[1,2,3],f32>
}

View File

@ -170,3 +170,14 @@ func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,5
%arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>
return %arg2 : !torch.list<vtensor<[1,?,56,96],f16>>
}
// Check that verification passes with '-1' as a permutation index.
func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
%intm1 = torch.constant.int -1
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
return %3 : !torch.vtensor<[1,2,3],f32>
}