mirror of https://github.com/llvm/torch-mlir
Add correct type checking for tm_tensor.attention
parent
5223f990df
commit
b9d29dc055
|
@ -92,18 +92,12 @@ LogicalResult AttentionOp::verify() {
|
|||
Operation *op = getOperation();
|
||||
ShapedType queryType = getQueryType();
|
||||
ShapedType keyType = getKeyType();
|
||||
ShapedType valueType = getValueType();
|
||||
ShapedType outputType = getOutputType();
|
||||
ArrayRef<int64_t> queryShape = queryType.getShape();
|
||||
ArrayRef<int64_t> keyShape = keyType.getShape();
|
||||
ArrayRef<int64_t> valueShape = valueType.getShape();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
if (failed(verifyCompatibleShape(queryShape, keyShape)))
|
||||
return op->emitOpError("incompatible key shape");
|
||||
if (failed(verifyCompatibleShape(queryShape, valueShape)))
|
||||
return op->emitOpError("incompatible value shape");
|
||||
if (failed(verifyCompatibleShape(queryShape, outputShape)))
|
||||
return op->emitOpError("incompatible output shape");
|
||||
if (keyShape[0] != queryShape[0])
|
||||
return op->emitOpError("query and key batch mismatch");
|
||||
if (keyShape[2] != queryShape[2])
|
||||
return op->emitOpError("query and key head dimension mismatch");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue