mirror of https://github.com/llvm/torch-mlir
Add verification for torch permute op (#2551)
- adds support for an optional verifier to the generated torch op tablegen (GeneratedTorchOps.td) - uses the above to add a verifier for the torch permute op. Motivation: I hit an unclear error from linalg while developing a decomposition pass for pixel_shuffle. The error would have been clearer if the problem had been detected earlier in the invalid aten.permute op. Testing: new tests added. To run added tests, from the base directory run ``` ./build/bin/llvm-lit test/Dialect/Torch/invalid.mlir ```pull/2563/head snapshot-20231116.1024
parent
e81282ae8f
commit
dad1f012f6
|
@ -6422,29 +6422,6 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$dims
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenPermuteOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -6469,6 +6446,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$dims
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenPermuteOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -2859,6 +2859,96 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AtenPermuteOp::verify() {
|
||||
|
||||
// Verification of the permute op for input & output dimensions with
|
||||
// statically known sizes.
|
||||
|
||||
SmallVector<Value> permutation;
|
||||
auto permutationObtained = getListConstructElements(getDims(), permutation);
|
||||
if (!permutationObtained) {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto outType = getResult().getType().cast<BaseTensorType>();
|
||||
auto inType = getSelf().getType().cast<BaseTensorType>();
|
||||
|
||||
if (!outType.hasSizes() || !inType.hasSizes()) {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto outShape = outType.getSizes();
|
||||
auto inShape = inType.getSizes();
|
||||
|
||||
auto outRank = outShape.size();
|
||||
|
||||
if (outRank != inShape.size()) {
|
||||
return emitOpError(
|
||||
"expected input and output tensors to have same rank, but ")
|
||||
<< inShape.size() << " != " << outRank << '.';
|
||||
}
|
||||
|
||||
if (outRank != permutation.size()) {
|
||||
return emitOpError() << "expected permutation to have size equal result "
|
||||
"tensor rank. The permutation has "
|
||||
<< permutation.size()
|
||||
<< " elements, the output has rank " << outRank << '.';
|
||||
}
|
||||
|
||||
|
||||
// Initialization of the reverse permutation. -1 denotes an unknown
|
||||
// permutation index.
|
||||
SmallVector<int64_t> reversePermutation(outRank, -1);
|
||||
|
||||
// In this loop:
|
||||
// (1) check that the permutation indices are in bounds, and not duplicated.
|
||||
// (2) populate reversePermutation (to check for duplicates).
|
||||
// (3) check that the input and output shapes agree with the permutation. For
|
||||
// example, if the permutation is (1,2,0) and the input shape is (2,3,5),
|
||||
// then the output shape must be (3,5,2).
|
||||
|
||||
for (uint64_t to = 0; to < outRank; ++to) {
|
||||
int64_t from;
|
||||
|
||||
auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from));
|
||||
|
||||
if (!fromIsSet) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if 'from' is the unkwown index, continue.
|
||||
if (from == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isValidDim(from, outRank)) {
|
||||
return emitError("observed invalid index in permutation (")
|
||||
<< from << ") for input tensor of rank " << outRank << '.';
|
||||
}
|
||||
|
||||
if (reversePermutation[from] != -1) {
|
||||
return emitOpError("has a duplicate dimension (")
|
||||
<< from << ") in its permutation " << getDims() << '.';
|
||||
}
|
||||
reversePermutation[from] = to;
|
||||
|
||||
auto dimSizesDefined =
|
||||
inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
|
||||
auto dimSizesDifferent = inShape[from] != outShape[to];
|
||||
|
||||
if (dimSizesDefined && dimSizesDifferent) {
|
||||
return emitOpError("has a permutation which is not compatible with the "
|
||||
"input and output shapes. ")
|
||||
<< "The input shape in dimension " << from << " is "
|
||||
<< inShape[from] << ", and the output shape in dimension " << to
|
||||
<< " is " << outShape[to]
|
||||
<< " : they should be the same with this permutation. ";
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DtypeCalculateYieldDtypesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -114,7 +114,7 @@ ODS_BANNER = f"""//===-------------------------------------------------------*-
|
|||
def raw_emit_op(operator: JitOperator,
|
||||
emitter_td: TextEmitter,
|
||||
*, traits: List[str],
|
||||
has_folder: bool, has_canonicalizer: bool):
|
||||
has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
|
||||
"""Emit the ODS for a JitOperator to a textual file.
|
||||
|
||||
This is the lowest level of emission and is responsible for low-level
|
||||
|
@ -199,6 +199,8 @@ def raw_emit_op(operator: JitOperator,
|
|||
p_td("let hasFolder = 1;")
|
||||
if has_canonicalizer:
|
||||
p_td("let hasCanonicalizer = 1;")
|
||||
if has_verifier:
|
||||
p_td("let hasVerifier = 1;")
|
||||
p_td("}")
|
||||
p_td("\n")
|
||||
|
||||
|
@ -208,7 +210,8 @@ def emit_op(operator: JitOperator,
|
|||
*,
|
||||
traits: Optional[List[str]] = None,
|
||||
has_folder: bool = False,
|
||||
has_canonicalizer: bool = False):
|
||||
has_canonicalizer: bool = False,
|
||||
has_verifier: bool = False):
|
||||
"""Main entry point for op emission.
|
||||
|
||||
Besides emitting the op, it deduces / adds traits based on the operator
|
||||
|
@ -228,7 +231,8 @@ def emit_op(operator: JitOperator,
|
|||
emitter_td,
|
||||
traits=traits,
|
||||
has_folder=has_folder,
|
||||
has_canonicalizer=has_canonicalizer)
|
||||
has_canonicalizer=has_canonicalizer,
|
||||
has_verifier=has_verifier)
|
||||
|
||||
|
||||
def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||
|
@ -481,8 +485,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::permute : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
|
||||
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
|
||||
|
|
|
@ -281,3 +281,84 @@ func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,
|
|||
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
|
||||
return %0 : !torch.vtensor<*,f64>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.permute$test_changing_rank (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
||||
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
|
||||
%perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
|
||||
// expected-error@+1 {{expected input and output tensors to have same rank, but 3 != 4}}
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
|
||||
|
||||
return %3 : !torch.vtensor<[1,2,3,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.permute$test_permutation_too_short (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
|
||||
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
|
||||
%perm = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
||||
// expected-error@+1 {{The permutation has 2 elements, the output has rank 3}}
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
|
||||
|
||||
return %3 : !torch.vtensor<[1,2,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.permute$duplicate_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[2,3,1],f32> {
|
||||
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%perm = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
|
||||
// expected-error@+1 {{'torch.aten.permute' op has a duplicate dimension (1) in its permutation}}
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1],f32>
|
||||
|
||||
return %3 : !torch.vtensor<[2,3,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.permute$incorrect_output_shape (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[3,1,2],f32> {
|
||||
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%none = torch.constant.none
|
||||
|
||||
%perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
|
||||
// expected-error@+1 {{'torch.aten.permute' op has a permutation which is not compatible with the input and output shapes. The input shape in dimension 1 is 2, and the output shape in dimension 0 is 3 : they should be the same with this permutation.}}
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[3,1,2],f32>
|
||||
|
||||
return %3 : !torch.vtensor<[3,1,2],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
|
||||
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int7 = torch.constant.int 7
|
||||
%perm = torch.prim.ListConstruct %int0, %int1, %int7 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
|
||||
|
||||
// expected-error@+1 {{observed invalid index in permutation (7) for input tensor of rank 3.}}
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
|
||||
|
||||
return %3 : !torch.vtensor<[1,2,3],f32>
|
||||
}
|
||||
|
||||
|
|
|
@ -170,3 +170,14 @@ func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,5
|
|||
%arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>
|
||||
return %arg2 : !torch.list<vtensor<[1,?,56,96],f16>>
|
||||
}
|
||||
|
||||
// Check that verification passes with '-1' as a permutation index.
|
||||
func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
|
||||
%intm1 = torch.constant.int -1
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
|
||||
return %3 : !torch.vtensor<[1,2,3],f32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue