Add correct type checking for tm_tensor.attention

pull/2182/head snapshot-20230527.851
George Petterson 2023-05-25 19:04:54 -04:00 committed by Prashant Kumar
parent 5223f990df
commit b9d29dc055
1 changed files with 4 additions and 10 deletions

View File

@ -92,18 +92,12 @@ LogicalResult AttentionOp::verify() {
Operation *op = getOperation(); Operation *op = getOperation();
ShapedType queryType = getQueryType(); ShapedType queryType = getQueryType();
ShapedType keyType = getKeyType(); ShapedType keyType = getKeyType();
ShapedType valueType = getValueType();
ShapedType outputType = getOutputType();
ArrayRef<int64_t> queryShape = queryType.getShape(); ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape(); ArrayRef<int64_t> keyShape = keyType.getShape();
ArrayRef<int64_t> valueShape = valueType.getShape(); if (keyShape[0] != queryShape[0])
ArrayRef<int64_t> outputShape = outputType.getShape(); return op->emitOpError("query and key batch mismatch");
if (failed(verifyCompatibleShape(queryShape, keyShape))) if (keyShape[2] != queryShape[2])
return op->emitOpError("incompatible key shape"); return op->emitOpError("query and key head dimension mismatch");
if (failed(verifyCompatibleShape(queryShape, valueShape)))
return op->emitOpError("incompatible value shape");
if (failed(verifyCompatibleShape(queryShape, outputShape)))
return op->emitOpError("incompatible output shape");
return success(); return success();
} }