mirror of https://github.com/llvm/torch-mlir
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)
The op can be valid with no attached shape symbols if they are not required by the corresponding affine map. Fix the verifier to consider number of arguments for both.pull/3756/head
parent
b1413a6c7f
commit
617c1c76ce
|
@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult BindSymbolicShapeOp::verify() {
|
LogicalResult BindSymbolicShapeOp::verify() {
|
||||||
if (getShapeSymbols().empty())
|
if (getShapeSymbols().size() !=
|
||||||
return emitOpError() << "requires non-empty shapeSymbols";
|
getShapeExpressions().getValue().getNumSymbols())
|
||||||
|
return emitOpError()
|
||||||
|
<< "requires equal number of shape symbol args and symbol args to "
|
||||||
|
"the attached affine map, since they are 1:1 mapped";
|
||||||
|
|
||||||
for (auto symbol : getShapeSymbols()) {
|
for (auto symbol : getShapeSymbols()) {
|
||||||
Operation *definingOp = symbol.getDefiningOp();
|
Operation *definingOp = symbol.getDefiningOp();
|
||||||
|
|
|
@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
|
||||||
|
|
||||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||||
// expected-error @+1 {{op requires non-empty shapeSymbols}}
|
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
|
||||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||||
return %arg0 : !torch.vtensor<[?],f32>
|
return %arg0 : !torch.vtensor<[?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Verifier should not fail here since the op does not require shapeSymbols.
|
||||||
|
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
|
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
|
||||||
|
return %arg0 : !torch.vtensor<[?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
||||||
|
|
Loading…
Reference in New Issue