mirror of https://github.com/llvm/torch-mlir
Add correct type checking for tm_tensor.attention
parent
5223f990df
commit
b9d29dc055
|
@ -92,18 +92,12 @@ LogicalResult AttentionOp::verify() {
|
||||||
Operation *op = getOperation();
|
Operation *op = getOperation();
|
||||||
ShapedType queryType = getQueryType();
|
ShapedType queryType = getQueryType();
|
||||||
ShapedType keyType = getKeyType();
|
ShapedType keyType = getKeyType();
|
||||||
ShapedType valueType = getValueType();
|
|
||||||
ShapedType outputType = getOutputType();
|
|
||||||
ArrayRef<int64_t> queryShape = queryType.getShape();
|
ArrayRef<int64_t> queryShape = queryType.getShape();
|
||||||
ArrayRef<int64_t> keyShape = keyType.getShape();
|
ArrayRef<int64_t> keyShape = keyType.getShape();
|
||||||
ArrayRef<int64_t> valueShape = valueType.getShape();
|
if (keyShape[0] != queryShape[0])
|
||||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
return op->emitOpError("query and key batch mismatch");
|
||||||
if (failed(verifyCompatibleShape(queryShape, keyShape)))
|
if (keyShape[2] != queryShape[2])
|
||||||
return op->emitOpError("incompatible key shape");
|
return op->emitOpError("query and key head dimension mismatch");
|
||||||
if (failed(verifyCompatibleShape(queryShape, valueShape)))
|
|
||||||
return op->emitOpError("incompatible value shape");
|
|
||||||
if (failed(verifyCompatibleShape(queryShape, outputShape)))
|
|
||||||
return op->emitOpError("incompatible output shape");
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue