Add folder for aten.broadcast_to on unchanged static shapes (#2421)

pull/2431/head
Quinn Dawkins 2023-09-01 14:50:34 -04:00 committed by GitHub
parent 34a0897e1b
commit 1fc4314b62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 1 deletions

View File

@ -8356,6 +8356,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [

View File

@ -2317,6 +2317,25 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
return list.getElements()[0]; return list.getElements()[0];
} }
//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
return nullptr;
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
if (inType.getSizes()[i] != outType.getSizes()[i])
return nullptr;
}
return getOperand(0);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSliceTensorOp // AtenSliceTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -542,7 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)") emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")

View File

@ -1971,6 +1971,18 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te
return %1: !torch.tensor return %1: !torch.tensor
} }
// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>
func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%int2 = torch.constant.int 2
%list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[3,4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
return %0 : !torch.vtensor<[3,4,2],f32>
}
// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>