mirror of https://github.com/llvm/torch-mlir
Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767)
Reverted due to downstream model changes. Will reland with fixes post
integration.
This reverts commit 6e8c7bed4b
.
pull/3644/merge
parent
e9ed4af9ce
commit
53f7532e76
|
@ -1627,25 +1627,6 @@ public:
|
|||
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
bool isScalarOp = resultType.getShape().size() == 0;
|
||||
if (isScalarOp) {
|
||||
// for elementwise ops that are actually rank0 scalar computations,
|
||||
// perform the payload outside a linalg generic op.
|
||||
SmallVector<Value> payloadArgs;
|
||||
for (auto t : tensorOperands) {
|
||||
payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
|
||||
}
|
||||
Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
|
||||
rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
|
||||
if (!scalarResult)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Failed to create payload for scalar elementwise op");
|
||||
Value rank0Result =
|
||||
createInitTensor(rewriter, loc, ValueRange{},
|
||||
resultType.getElementType(), scalarResult);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
|
||||
return success();
|
||||
}
|
||||
bool hadErrorCreatingPayload = false;
|
||||
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, tensorOperands, resultType.getElementType(),
|
||||
|
|
|
@ -4,11 +4,13 @@
|
|||
// CHECK-LABEL: func.func @elementwise$unary(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
|
||||
// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
|
||||
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
|
||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
|
||||
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
|
||||
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
|
||||
// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
|
||||
// CHECK: linalg.yield %[[TANH]] : f32
|
||||
// CHECK: } -> tensor<f32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
|
||||
// CHECK: }
|
||||
|
|
Loading…
Reference in New Issue