diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index f6f63697d..ba7ed76c8 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -92,18 +92,12 @@ LogicalResult AttentionOp::verify() { Operation *op = getOperation(); ShapedType queryType = getQueryType(); ShapedType keyType = getKeyType(); - ShapedType valueType = getValueType(); - ShapedType outputType = getOutputType(); ArrayRef queryShape = queryType.getShape(); ArrayRef keyShape = keyType.getShape(); - ArrayRef valueShape = valueType.getShape(); - ArrayRef outputShape = outputType.getShape(); - if (failed(verifyCompatibleShape(queryShape, keyShape))) - return op->emitOpError("incompatible key shape"); - if (failed(verifyCompatibleShape(queryShape, valueShape))) - return op->emitOpError("incompatible value shape"); - if (failed(verifyCompatibleShape(queryShape, outputShape))) - return op->emitOpError("incompatible output shape"); + if (keyShape[0] != queryShape[0]) + return op->emitOpError("query and key batch mismatch"); + if (keyShape[2] != queryShape[2]) + return op->emitOpError("query and key head dimension mismatch"); return success(); }