Added TorchToLinalg conversion for Aten_LocalScalarDenseOp

pull/2925/head
Max Dawkins 2024-02-15 14:25:49 -05:00
parent f8332a4da2
commit 2232780216
2 changed files with 5 additions and 2 deletions

View File

@ -10324,7 +10324,7 @@ def Torch_Aten_LocalScalarDenseOp : Torch_Op<"aten._local_scalar_dense", [
]> { ]> {
let summary = "Generated op for `aten::_local_scalar_dense : (Tensor) -> (Scalar)`"; let summary = "Generated op for `aten::_local_scalar_dense : (Tensor) -> (Scalar)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self AnyTorchTensorType:$a
); );
let results = (outs let results = (outs
AnyTorchScalarType:$result AnyTorchScalarType:$result

View File

@ -213,13 +213,16 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context); patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
target.addIllegalOp<AtenNumelOp>(); target.addIllegalOp<AtenNumelOp>();
patterns.add<ConvertAtenNumelOp>(typeConverter, context); patterns.add<ConvertAtenNumelOp>(typeConverter, context);
target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>(); target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp,
Aten_LocalScalarDenseOp>();
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(typeConverter, patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(typeConverter,
context); context);
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>( patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
typeConverter, context); typeConverter, context);
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(typeConverter, patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(typeConverter,
context); context);
patterns.add<ConvertAtenTensorToScalarLikeOp<Aten_LocalScalarDenseOp>>(
typeConverter, context);
target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>(); target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context); patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>(); target.addIllegalOp<PrimNumToTensorScalarOp>();