mirror of https://github.com/llvm/torch-mlir
Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767)
Reverted due to downstream model changes. Will reland with fixes post
integration.
This reverts commit 6e8c7bed4b
.
pull/3644/merge
parent
e9ed4af9ce
commit
53f7532e76
|
@ -1627,25 +1627,6 @@ public:
|
||||||
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
|
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
|
||||||
auto resultType = cast<RankedTensorType>(
|
auto resultType = cast<RankedTensorType>(
|
||||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
bool isScalarOp = resultType.getShape().size() == 0;
|
|
||||||
if (isScalarOp) {
|
|
||||||
// for elementwise ops that are actually rank0 scalar computations,
|
|
||||||
// perform the payload outside a linalg generic op.
|
|
||||||
SmallVector<Value> payloadArgs;
|
|
||||||
for (auto t : tensorOperands) {
|
|
||||||
payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
|
|
||||||
}
|
|
||||||
Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
|
|
||||||
rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
|
|
||||||
if (!scalarResult)
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "Failed to create payload for scalar elementwise op");
|
|
||||||
Value rank0Result =
|
|
||||||
createInitTensor(rewriter, loc, ValueRange{},
|
|
||||||
resultType.getElementType(), scalarResult);
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
bool hadErrorCreatingPayload = false;
|
bool hadErrorCreatingPayload = false;
|
||||||
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
|
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
rewriter, loc, tensorOperands, resultType.getElementType(),
|
rewriter, loc, tensorOperands, resultType.getElementType(),
|
||||||
|
|
|
@ -4,11 +4,13 @@
|
||||||
// CHECK-LABEL: func.func @elementwise$unary(
|
// CHECK-LABEL: func.func @elementwise$unary(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||||
// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
|
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
|
||||||
// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
|
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
|
||||||
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
|
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
|
// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
|
||||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
|
// CHECK: linalg.yield %[[TANH]] : f32
|
||||||
|
// CHECK: } -> tensor<f32>
|
||||||
|
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
Loading…
Reference in New Issue