Added TorchToLinalg conversion for Aten_LocalScalarDenseOp

pull/2925/head
Max Dawkins 2024-02-15 14:25:49 -05:00
parent f8332a4da2
commit 2232780216
2 changed files with 5 additions and 2 deletions

View File

@ -10324,7 +10324,7 @@ def Torch_Aten_LocalScalarDenseOp : Torch_Op<"aten._local_scalar_dense", [
]> {
let summary = "Generated op for `aten::_local_scalar_dense : (Tensor) -> (Scalar)`";
let arguments = (ins
AnyTorchTensorType:$self
AnyTorchTensorType:$a
);
let results = (outs
AnyTorchScalarType:$result

View File

@ -213,13 +213,16 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
target.addIllegalOp<AtenNumelOp>();
patterns.add<ConvertAtenNumelOp>(typeConverter, context);
target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp,
Aten_LocalScalarDenseOp>();
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(typeConverter,
context);
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
typeConverter, context);
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(typeConverter,
context);
patterns.add<ConvertAtenTensorToScalarLikeOp<Aten_LocalScalarDenseOp>>(
typeConverter, context);
target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>();