Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767)

Reverted due to downstream model changes. Will reland with fixes post-integration.

This reverts commit 6e8c7bed4b.
Rob Suderman 2024-10-04 14:48:02 -07:00 committed by GitHub
parent e9ed4af9ce
commit 53f7532e76
2 changed files with 7 additions and 24 deletions


@@ -1627,25 +1627,6 @@ public:
         operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
     auto resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op->getResult(0).getType()));
-    bool isScalarOp = resultType.getShape().size() == 0;
-    if (isScalarOp) {
-      // for elementwise ops that are actually rank0 scalar computations,
-      // perform the payload outside a linalg generic op.
-      SmallVector<Value> payloadArgs;
-      for (auto t : tensorOperands) {
-        payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
-      }
-      Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
-          rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
-      if (!scalarResult)
-        return rewriter.notifyMatchFailure(
-            op, "Failed to create payload for scalar elementwise op");
-      Value rank0Result =
-          createInitTensor(rewriter, loc, ValueRange{},
-                           resultType.getElementType(), scalarResult);
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
-      return success();
-    }
     bool hadErrorCreatingPayload = false;
     Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
         rewriter, loc, tensorOperands, resultType.getElementType(),
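For context, the reverted fast path skipped linalg.generic entirely for rank-0 results. Assuming a unary tanh on a zero-rank tensor (the case exercised by the test below), the removed code emitted IR roughly like this sketch, reconstructed from the test's CHECK lines rather than taken from actual pass output:

    // pull the single element out and compute the payload directly
    %extract = tensor.extract %operand[] : tensor<f32>
    %tanh = math.tanh %extract : f32
    // materialize the scalar back into a rank-0 tensor
    %empty = tensor.empty() : tensor<f32>
    %fill = linalg.fill ins(%tanh : f32) outs(%empty : tensor<f32>) -> tensor<f32>
    %cast = tensor.cast %fill : tensor<f32> to tensor<f32>

With the revert applied, rank-0 ops flow through torch_to_linalg::createElementwiseLinalgGeneric like every other elementwise op, producing the rank-0 linalg.generic checked in the updated test below.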


@@ -4,11 +4,13 @@
 // CHECK-LABEL: func.func @elementwise$unary(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
 // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
-// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
-// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
-// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
-// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
+// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
+// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
+// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
+// CHECK: linalg.yield %[[TANH]] : f32
+// CHECK: } -> tensor<f32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
 // CHECK: }
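The test input itself lies outside the hunk; a minimal function consistent with these CHECK lines would look roughly like the following (a hypothetical reconstruction; the actual body in the test file may differ):

    func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
      // tanh on a rank-0 tensor; lowers to the rank-0 linalg.generic checked above
      %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
      return %0 : !torch.vtensor<[],f32>
    }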