mirror of https://github.com/llvm/torch-mlir
Fix lowering of reduce ops
We were not filling the `outs` with the neutral element of the reduction, which resulted in reading uninitialized values (we were getting lucky that sometimes the uninitialized buffers happened to be all zeros). Also: slight tweak to error messages in the e2e framework.
parent 6724de7692
commit 5f3eb637c4
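The fix in the diffs below amounts to seeding the reduction's output buffer with the neutral element before accumulating into it. As a minimal standalone sketch (plain C++, illustrative only, not the torch-mlir code; `buggySum`/`fixedSum` are hypothetical names), this is the failure mode the commit message describes:

// Illustrative only: an accumulator that is never initialized reads
// indeterminate memory -- the same failure mode as a linalg `outs`
// tensor that was never filled with the reduction's neutral element.
#include <cstdio>
#include <vector>

double buggySum(const std::vector<double> &xs) {
  double acc;        // uninitialized: result depends on whatever was in memory
  for (double x : xs)
    acc += x;
  return acc;
}

double fixedSum(const std::vector<double> &xs) {
  double acc = 0.0;  // neutral element of addition, analogous to the linalg::FillOp added in this commit
  for (double x : xs)
    acc += x;
  return acc;
}

int main() {
  std::vector<double> xs = {1.0, 2.0, 3.0};
  std::printf("%f\n", fixedSum(xs));  // prints 6.000000
}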
@@ -22,9 +22,10 @@ class TensorSummary:
         self.min = torch.min(tensor)
         self.max = torch.max(tensor)
         self.mean = torch.mean(tensor)
+        self.shape = list(tensor.shape)
 
     def __str__(self):
-        return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
+        return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'
 
 
 class ErrorContext:
@@ -116,7 +116,7 @@ class ErroneousModule(torch.nn.Module):
     # CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch"
     # CHECK-NEXT: @ output of call to "test_tensor_value_mismatch"
-    # CHECK-NEXT: ERROR: value (Tensor with min=+1.0, max=+3.0, mean=+2.0000) is not close to golden value (Tensor with min=+1.5, max=+3.5, mean=+2.5000)
+    # CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5)
     @torch.jit.export
     def test_tensor_value_mismatch(self):
         if torch.jit.is_scripting():
@@ -755,6 +755,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   return nullptr;
 }
 
+static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
+                                                    Operation *op,
+                                                    Type elementType) {
+  if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
+      elementType.isa<mlir::FloatType>())
+    return b.create<mlir::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
+
+  op->emitError("unimplemented lowering in "
+                "createLinalgNeutralElementForReduceOp");
+  return nullptr;
+}
+
 static Value createLinalgPayloadCalculationForReduceOp(
     OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
     ArrayRef<Value> operands, Type elementType) {
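The new helper above only covers sum-style reductions over float element types; everything else falls through to the emitError path. For context (a hedged sketch, not part of this commit), the neutral element is simply the value that leaves the reduction unchanged when combined with any operand:

#include <limits>

// Illustrative only: neutral elements for a few common reductions.
enum class ReduceKind { Sum, Prod, Max, Min };

static double neutralElement(ReduceKind kind) {
  switch (kind) {
  case ReduceKind::Sum:  return 0.0;  // x + 0 == x
  case ReduceKind::Prod: return 1.0;  // x * 1 == x
  case ReduceKind::Max:  return -std::numeric_limits<double>::infinity();  // max(x, -inf) == x
  case ReduceKind::Min:  return  std::numeric_limits<double>::infinity();  // min(x, +inf) == x
  }
  return 0.0;  // unreachable for the cases above
}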
@@ -981,11 +993,16 @@ struct ConvertReductionOp : ConversionPattern {
     auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, resultShape, resultType.getElementType());
+    Value initValue = createLinalgNeutralElementForReduceOp(
+        rewriter, loc, op, resultType.getElementType());
+    Value accumulator =
+        rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
+            .getResult(0);
     bool hadErrorCreatingPayload = false;
     auto generic = rewriter.create<linalg::GenericOp>(
-        loc, /*resultTensorTypes=*/initTensor.getType(),
+        loc, /*resultTensorTypes=*/accumulator.getType(),
         /*inputs=*/tensorOperand,
-        /*outputs=*/initTensor,
+        /*outputs=*/accumulator,
         /*indexingMaps=*/indexingMaps,
         /*iteratorTypes=*/iteratorTypes,
         [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {