mirror of https://github.com/llvm/torch-mlir
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)
The op can be valid with no attached shape symbols if they are not required by the corresponding affine map. Fix the verifier to consider number of arguments for both.pull/3756/head
parent
b1413a6c7f
commit
617c1c76ce
|
@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
|
|||
}
|
||||
|
||||
LogicalResult BindSymbolicShapeOp::verify() {
|
||||
if (getShapeSymbols().empty())
|
||||
return emitOpError() << "requires non-empty shapeSymbols";
|
||||
if (getShapeSymbols().size() !=
|
||||
getShapeExpressions().getValue().getNumSymbols())
|
||||
return emitOpError()
|
||||
<< "requires equal number of shape symbol args and symbol args to "
|
||||
"the attached affine map, since they are 1:1 mapped";
|
||||
|
||||
for (auto symbol : getShapeSymbols()) {
|
||||
Operation *definingOp = symbol.getDefiningOp();
|
||||
|
|
|
@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
|
|||
|
||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||
// expected-error @+1 {{op requires non-empty shapeSymbols}}
|
||||
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
|
||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
return %arg0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verifier should not fail here since the op does not require shapeSymbols.
|
||||
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
|
||||
return %arg0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
||||
|
|
Loading…
Reference in New Issue