mirror of https://github.com/llvm/torch-mlir
Fix lowering of reduce ops
We were not filling the `outs` with the neutral element of the reduction, which resulted in reading uninitialized values (we were getting lucky that sometimes the uninitialized buffers were all zero's). Also, - Slight tweak to error messages in the e2e framework.pull/302/head
parent
6724de7692
commit
5f3eb637c4
|
@ -22,9 +22,10 @@ class TensorSummary:
|
||||||
self.min = torch.min(tensor)
|
self.min = torch.min(tensor)
|
||||||
self.max = torch.max(tensor)
|
self.max = torch.max(tensor)
|
||||||
self.mean = torch.mean(tensor)
|
self.mean = torch.mean(tensor)
|
||||||
|
self.shape = list(tensor.shape)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
|
return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'
|
||||||
|
|
||||||
|
|
||||||
class ErrorContext:
|
class ErrorContext:
|
||||||
|
|
|
@ -116,7 +116,7 @@ class ErroneousModule(torch.nn.Module):
|
||||||
|
|
||||||
# CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch"
|
# CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch"
|
||||||
# CHECK-NEXT: @ output of call to "test_tensor_value_mismatch"
|
# CHECK-NEXT: @ output of call to "test_tensor_value_mismatch"
|
||||||
# CHECK-NEXT: ERROR: value (Tensor with min=+1.0, max=+3.0, mean=+2.0000) is not close to golden value (Tensor with min=+1.5, max=+3.5, mean=+2.5000)
|
# CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5)
|
||||||
@torch.jit.export
|
@torch.jit.export
|
||||||
def test_tensor_value_mismatch(self):
|
def test_tensor_value_mismatch(self):
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
|
|
|
@ -755,6 +755,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
|
||||||
|
Operation *op,
|
||||||
|
Type elementType) {
|
||||||
|
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
|
||||||
|
elementType.isa<mlir::FloatType>())
|
||||||
|
return b.create<mlir::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
|
||||||
|
|
||||||
|
op->emitError("unimplemented lowering in "
|
||||||
|
"createLinalgNeutralElementForReduceOp");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
static Value createLinalgPayloadCalculationForReduceOp(
|
static Value createLinalgPayloadCalculationForReduceOp(
|
||||||
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
|
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
|
||||||
ArrayRef<Value> operands, Type elementType) {
|
ArrayRef<Value> operands, Type elementType) {
|
||||||
|
@ -981,11 +993,16 @@ struct ConvertReductionOp : ConversionPattern {
|
||||||
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
|
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
|
||||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
loc, resultShape, resultType.getElementType());
|
loc, resultShape, resultType.getElementType());
|
||||||
|
Value initValue = createLinalgNeutralElementForReduceOp(
|
||||||
|
rewriter, loc, op, resultType.getElementType());
|
||||||
|
Value accumulator =
|
||||||
|
rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
|
||||||
|
.getResult(0);
|
||||||
bool hadErrorCreatingPayload = false;
|
bool hadErrorCreatingPayload = false;
|
||||||
auto generic = rewriter.create<linalg::GenericOp>(
|
auto generic = rewriter.create<linalg::GenericOp>(
|
||||||
loc, /*resultTensorTypes=*/initTensor.getType(),
|
loc, /*resultTensorTypes=*/accumulator.getType(),
|
||||||
/*inputs=*/tensorOperand,
|
/*inputs=*/tensorOperand,
|
||||||
/*outputs=*/initTensor,
|
/*outputs=*/accumulator,
|
||||||
/*indexingMaps=*/indexingMaps,
|
/*indexingMaps=*/indexingMaps,
|
||||||
/*iteratorTypes=*/iteratorTypes,
|
/*iteratorTypes=*/iteratorTypes,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||||
|
|
Loading…
Reference in New Issue