[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)

The op can be valid with no attached shape symbols if they are not
required by the corresponding affine map. Fix the verifier to consider
number of arguments for both.
pull/3756/head
Prathamesh Tagore 2024-10-02 18:25:54 +05:30 committed by GitHub
parent b1413a6c7f
commit 617c1c76ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 3 deletions

View File

@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
}
LogicalResult BindSymbolicShapeOp::verify() {
if (getShapeSymbols().empty())
return emitOpError() << "requires non-empty shapeSymbols";
if (getShapeSymbols().size() !=
getShapeExpressions().getValue().getNumSymbols())
return emitOpError()
<< "requires equal number of shape symbol args and symbol args to "
"the attached affine map, since they are 1:1 mapped";
for (auto symbol : getShapeSymbols()) {
Operation *definingOp = symbol.getDefiningOp();

View File

@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
// expected-error @+1 {{op requires non-empty shapeSymbols}}
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}
// -----
// Verifier should not fail here since the op does not require shapeSymbols.
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}
// -----
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%int0 = torch.constant.int 0
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}