mirror of https://github.com/llvm/torch-mlir
Add a number of kernels and new patterns.

* convolution, convolution_backward, _log_softmax, _log_softmax_backward_data,
  nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward,
  copy_
* Extends the recognition logic and metadata to handle in-place transformations,
  optional tensors, ints, lists, and dropped args.
* The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen.
* The result comes out *almost* as a pure tensor program, except for the copy_
  op, which I will handle in follow-up work.
* More progress on #97
parent 3dab9056f0
commit 6c702b149f
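The new kernels are registered through the builder-style definition flow this change introduces in torch_signature_ods_gen.py (OpGenerator.define_op returning an InflightOpDef that is configured and then emitted). A minimal sketch of that flow, using only names that appear in the diff below; the emitter wiring and registry lookup are assumed to be set up as in the generator script:

# Sketch only: `g` is assumed to be an OpGenerator wired to the ODS and C++
# emitters, as constructed in torch_signature_ods_gen.py.
opdef = g.define_op(
    kernel_sig="aten::_log_softmax(Tensor,int,bool)",
    ods_name="LogSoftmaxOp",
    op_name="log_softmax",
    promote_trailing_out_tensor=True,
    traits=["NoSideEffect"])
opdef.transforms(
    type_transforms={
        "Tensor": "AnyTorchImmutableTensor",
        "int": "AnyTorchIntType",
        "bool": "AnyTorchBoolType",
    },
    flag_transforms={
        "Tensor": ["kImmutableTensor"],
    })
opdef.emit()  # Emits the ODS op definition and the kernel metadata methods.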
@@ -11,4 +11,5 @@ export PYTHONPATH="${build_dir}/python"
 python -m torch_mlir_utils.codegen.torch_signature_ods_gen \
   --ods_td_file="${aten_dir}/GeneratedATenOps.td" \
-  --ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc"
+  --ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" \
+  --debug_op_reg_file="${aten_dir}/ATenOpRegistrations.txt"
@@ -189,9 +189,8 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
   return constValue;
 }
 
-MlirValue
-FuncBuilder::buildList(MlirLocation loc,
-                       llvm::SmallVectorImpl<MlirValue> &elements) {
+MlirValue FuncBuilder::buildList(MlirLocation loc,
+                                 llvm::SmallVectorImpl<MlirValue> &elements) {
   MlirType resultType = npcompListTypeGet(context);
   OperationStateHolder state{"basicpy.build_list", loc};
   mlirOperationStateAddResults(state, 1, &resultType);
@@ -13,6 +13,7 @@ import logging
 import re
 import sys
 import textwrap
+import traceback
 
 # Note that this utility exists only in the c-extension.
 from _torch_mlir import get_registered_ops
@@ -75,6 +76,55 @@ def generate_ops(g: "OpGenerator"):
     g.ordinary_unary_op(f"aten::{uname}(Tensor)",
                         f"{snakecase_to_camelcase(uname)}Op", uname)
 
+  # Convolution ops. Note that these are special in PyTorch and the importer,
+  # and we model them after the signatures of the convolution_overrideable
+  # ops (generic for non-CPU/GPU backends) but set the names according to
+  # how they come in.
+  g.print_banner("NN ops")
+  g.ordinary_immutable_op(
+      "aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
+      "ConvolutionOp",
+      "convolution",
+      alias_kernel_names=["aten::convolution"])
+  g.ordinary_immutable_op(
+      "aten::convolution_backward_overrideable(Tensor,Tensor,Tensor,int[],int[],int[],bool,int[],int,bool[])",
+      "ConvolutionBackwardOp",
+      "convolution_backward",
+      alias_kernel_names=["aten::convolution_backward"],
+      # These do return as None but are not coded optional in the registry :(
+      override_return_types=["Tensor?", "Tensor?", "Tensor?"])
+
+  g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)",
+                          "LogSoftmaxOp", "log_softmax")
+  g.ordinary_immutable_op(
+      "aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
+      "LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
+
+  # Loss functions.
+  g.print_banner("Loss function ops")
+  g.ordinary_immutable_op(
+      "aten::nll_loss_forward(Tensor,Tensor,Tensor?,int,int)",
+      "NllLossForwardOp", "nll_loss_forward")
+  # Note also a grad_input 8-arg variant.
+  g.ordinary_immutable_op(
+      "aten::nll_loss_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
+      "NllLossBackwardOp", "nll_loss_backward")
+
+  g.ordinary_immutable_op(
+      "aten::nll_loss2d_forward(Tensor,Tensor,Tensor?,int,int)",
+      "NllLoss2dForwardOp", "nll_loss2d_forward")
+  # Note also a grad_input 8-arg variant.
+  g.ordinary_immutable_op(
+      "aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
+      "NllLoss2dBackwardOp", "nll_loss2d_backward")
+
+  # One-off in-place ops (note that many in-place arithmetic ops are handled
+  # as a transformation from their immutable forms).
+  g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
+                        "CopyInplaceOp",
+                        "copy.inplace",
+                        drop_arg_indices=[2])
+
+
 def dump_registered_ops(outfile, reg_ops_dict):
   for k in sorted(reg_ops_dict.keys()):
@@ -104,7 +154,20 @@ class OpGenerator:
     )
     em.print("")
 
-  def ordinary_binary_op(self, kernel_sig, ods_name, op_name):
+  def define_op(self, kernel_sig, ods_name, op_name, **kwargs):
+    return InflightOpDef(self,
+                         kernel_sig=kernel_sig,
+                         ods_name=ods_name,
+                         op_name=op_name,
+                         **kwargs)
+
+  def ordinary_binary_op(self,
+                         kernel_sig,
+                         ods_name,
+                         op_name,
+                         promote_trailing_out_tensor=True,
+                         traits=(),
+                         **kwargs):
     """"Binary"-ops. These ops typically have:
     - '.Tensor' variant where the second arg is a Tensor
     - '.Scalar' variant where the second arg is a Scalar
@ -124,75 +187,130 @@ class OpGenerator:
|
||||||
- Setting all arguments and returns to kImmutableTensor
|
- Setting all arguments and returns to kImmutableTensor
|
||||||
- Enabling kPromoteScalarToTensor on the second argument.
|
- Enabling kPromoteScalarToTensor on the second argument.
|
||||||
"""
|
"""
|
||||||
reg_record = self._get_reg_record(kernel_sig)
|
opdef = self.define_op(
|
||||||
ods_ins, arg_type_flags = self._map_sigtypes(
|
kernel_sig=kernel_sig,
|
||||||
reg_record["arguments"],
|
ods_name=ods_name,
|
||||||
type_transforms={
|
op_name=op_name,
|
||||||
"Tensor:0": "AnyTorchImmutableTensor",
|
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||||
"Tensor:1": "AnyTorchImmutableTensor",
|
traits=list(traits) + ["NoSideEffect"],
|
||||||
"Scalar:1": "AnyTorchImmutableTensor",
|
**kwargs)
|
||||||
"Scalar": "AnyTorchScalarType",
|
opdef.arg_transforms(type_transforms={
|
||||||
},
|
"Tensor:0": "AnyTorchImmutableTensor",
|
||||||
flag_transforms={
|
"Tensor:1": "AnyTorchImmutableTensor",
|
||||||
":0": ["kImmutableTensor"],
|
"Scalar:1": "AnyTorchImmutableTensor",
|
||||||
":1": ["kImmutableTensor", "kPromoteScalar"],
|
"Scalar": "AnyTorchScalarType",
|
||||||
})
|
},
|
||||||
ods_outs, return_type_flags = self._map_sigtypes(
|
flag_transforms={
|
||||||
reg_record["returns"],
|
":0": ["kImmutableTensor"],
|
||||||
type_transforms={
|
":1": ["kImmutableTensor", "kPromoteScalar"],
|
||||||
"Tensor:0": "AnyTorchImmutableTensor",
|
})
|
||||||
},
|
opdef.return_transforms(type_transforms={
|
||||||
flag_transforms={
|
"Tensor:0": "AnyTorchImmutableTensor",
|
||||||
":0": ["kImmutableTensor"],
|
},
|
||||||
})
|
flag_transforms={
|
||||||
self.ods_emitter.emit_opdef(ods_name,
|
":0": ["kImmutableTensor"],
|
||||||
op_name,
|
})
|
||||||
reg_record,
|
opdef.emit()
|
||||||
ods_ins=ods_ins,
|
|
||||||
ods_outs=ods_outs,
|
|
||||||
traits=["NoSideEffect"])
|
|
||||||
self.impl_emitter.emit_kernel_methods(ods_name,
|
|
||||||
reg_record,
|
|
||||||
arg_type_flags=arg_type_flags,
|
|
||||||
return_type_flags=return_type_flags,
|
|
||||||
promote_trailing_out_tensor=True)
|
|
||||||
|
|
||||||
def ordinary_unary_op(self, kernel_sig, ods_name, op_name):
|
def ordinary_immutable_op(self,
|
||||||
|
kernel_sig,
|
||||||
|
ods_name,
|
||||||
|
op_name,
|
||||||
|
promote_trailing_out_tensor=True,
|
||||||
|
traits=(),
|
||||||
|
**kwargs):
|
||||||
|
""""An ordinary immutable-tensor based op."""
|
||||||
|
opdef = self.define_op(
|
||||||
|
kernel_sig=kernel_sig,
|
||||||
|
ods_name=ods_name,
|
||||||
|
op_name=op_name,
|
||||||
|
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||||
|
traits=list(traits) + ["NoSideEffect"],
|
||||||
|
**kwargs)
|
||||||
|
opdef.transforms(type_transforms={
|
||||||
|
"Tensor": "AnyTorchImmutableTensor",
|
||||||
|
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||||
|
"int": "AnyTorchIntType",
|
||||||
|
"int[]": "AnyTorchIntListType",
|
||||||
|
"bool": "AnyTorchBoolType",
|
||||||
|
"bool[]": "AnyTorchBoolListType",
|
||||||
|
},
|
||||||
|
flag_transforms={
|
||||||
|
"Tensor": ["kImmutableTensor"],
|
||||||
|
"Tensor?": ["kImmutableTensor"],
|
||||||
|
})
|
||||||
|
opdef.emit()
|
||||||
|
|
||||||
|
def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs):
|
||||||
|
"""In-place ops (ending in '_').
|
||||||
|
|
||||||
|
These ops take a mutable first argument and then standard immutable
|
||||||
|
conversions for subsequent. When emitting into MLIR, the return value is
|
||||||
|
dropped.
|
||||||
|
"""
|
||||||
|
opdef = self.define_op(kernel_sig=kernel_sig,
|
||||||
|
ods_name=ods_name,
|
||||||
|
op_name=op_name,
|
||||||
|
**kwargs)
|
||||||
|
opdef.arg_transforms(type_transforms={
|
||||||
|
":0": "AnyTorchMutableTensor",
|
||||||
|
"Tensor": "AnyTorchImmutableTensor",
|
||||||
|
"Tensor?": "AnyTorchOptionalImmutableTensor",
|
||||||
|
"int": "AnyTorchIntType",
|
||||||
|
"int[]": "AnyTorchIntListType",
|
||||||
|
"bool": "AnyTorchBoolType",
|
||||||
|
"bool[]": "AnyTorchBoolListType",
|
||||||
|
},
|
||||||
|
flag_transforms={
|
||||||
|
":0": [],
|
||||||
|
"Tensor": ["kImmutableTensor"],
|
||||||
|
"Tensor?": ["kImmutableTensor"],
|
||||||
|
})
|
||||||
|
opdef.return_transforms(
|
||||||
|
type_transforms={
|
||||||
|
":0": "DROP_UNUSED", # Ignored because we clear the outs below.
|
||||||
|
},
|
||||||
|
flag_transforms={
|
||||||
|
":0": ["kDropReturnAndAliasArg0"],
|
||||||
|
})
|
||||||
|
opdef.map_signatures()
|
||||||
|
opdef.ods_outs = [] # Clear the computed outs.
|
||||||
|
opdef.emit()
|
||||||
|
|
||||||
|
def ordinary_unary_op(self,
|
||||||
|
kernel_sig,
|
||||||
|
ods_name,
|
||||||
|
op_name,
|
||||||
|
promote_trailing_out_tensor=True,
|
||||||
|
traits=(),
|
||||||
|
**kwargs):
|
||||||
"""Unary ops.
|
"""Unary ops.
|
||||||
|
|
||||||
These take and return a tensor and typically have an out and inplace
|
These take and return a tensor and typically have an out and inplace
|
||||||
variant (they may not but we generate patterns to match anyway).
|
variant (they may not but we generate patterns to match anyway).
|
||||||
"""
|
"""
|
||||||
reg_record = self._get_reg_record(kernel_sig)
|
opdef = self.define_op(
|
||||||
ods_ins, arg_type_flags = self._map_sigtypes(
|
kernel_sig=kernel_sig,
|
||||||
reg_record["arguments"],
|
ods_name=ods_name,
|
||||||
type_transforms={
|
op_name=op_name,
|
||||||
"Tensor:0": "AnyTorchImmutableTensor",
|
promote_trailing_out_tensor=promote_trailing_out_tensor,
|
||||||
},
|
traits=list(traits) + ["NoSideEffect"],
|
||||||
flag_transforms={
|
**kwargs)
|
||||||
":0": ["kImmutableTensor"],
|
opdef.arg_transforms(type_transforms={
|
||||||
})
|
"Tensor:0": "AnyTorchImmutableTensor",
|
||||||
ods_outs, return_type_flags = self._map_sigtypes(
|
},
|
||||||
reg_record["returns"],
|
flag_transforms={
|
||||||
type_transforms={
|
":0": ["kImmutableTensor"],
|
||||||
"Tensor:0": "AnyTorchImmutableTensor",
|
})
|
||||||
},
|
opdef.return_transforms(type_transforms={
|
||||||
flag_transforms={
|
"Tensor:0": "AnyTorchImmutableTensor",
|
||||||
":0": ["kImmutableTensor"],
|
},
|
||||||
})
|
flag_transforms={
|
||||||
self.ods_emitter.emit_opdef(ods_name,
|
":0": ["kImmutableTensor"],
|
||||||
op_name,
|
})
|
||||||
reg_record,
|
opdef.emit()
|
||||||
ods_ins=ods_ins,
|
|
||||||
ods_outs=ods_outs,
|
|
||||||
traits=["NoSideEffect"])
|
|
||||||
self.impl_emitter.emit_kernel_methods(ods_name,
|
|
||||||
reg_record,
|
|
||||||
arg_type_flags=arg_type_flags,
|
|
||||||
return_type_flags=return_type_flags,
|
|
||||||
promote_trailing_out_tensor=True)
|
|
||||||
|
|
||||||
def _get_reg_record(self, kernel_sig):
|
def get_reg_record(self, kernel_sig):
|
||||||
"""Gets the op-dict for a given registered op name.
|
"""Gets the op-dict for a given registered op name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -212,8 +330,14 @@ class OpGenerator:
       raise ValueError(f"Could not find registry op matching '{kernel_sig}'. "
                        f"Possible matches:\n {dym_message}")
 
-  def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str],
-                    flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]:
+  def _map_sigtypes(
+      self,
+      siglist: List[Dict],
+      type_transforms: Dict[str, str],
+      flag_transforms: Dict[str, List[str]],
+      drop_indices: Sequence[int] = (),
+      override_types: Optional[Sequence[str]] = None,
+  ) -> List[Tuple[str]]:
     """Maps a list of signature entries to ods dags and flag lists.
 
     The torch signature list contains dicts that minimally have keys 'name' and
@@ -233,15 +357,23 @@
     - An ods dag list of (ods_name, ods_type) tuples
     - List of (torch_type, [conversion_flag]) for specifying conversions.
    """
+    # Make sure any override types are sane.
+    if override_types:
+      assert len(override_types) == len(siglist), (
+          "Mismatch override and actual types")
    # Generate to ods dag list.
    ods_dag_list = []
    for i, sigitem in enumerate(siglist):
+      if i in drop_indices:
+        # Do not emit in ODS.
+        continue
      torch_name = sigitem["name"]
-      torch_type = sigitem["type"]
+      torch_type = (sigitem["type"]
+                    if override_types is None else override_types[i])
      # Look up the type transform.
-      ods_type = (type_transforms.get(f"{torch_type}:{i}") or
-                  type_transforms.get(f":{i}") or
+      ods_type = _first_non_none(type_transforms.get(f"{torch_type}:{i}"),
+                                 type_transforms.get(f":{i}"),
                                  type_transforms.get(torch_type))
      if not ods_type:
        raise ValueError(f"Signature item {i}, type {torch_type} did not match "
                         f"a type transform {type_transforms}")
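The lookup above prefers the most specific transform key. A small, self-contained illustration of the same precedence; the table values here are made up for the example, but the key forms match what _map_sigtypes consults:

def _first_non_none(*args):
  # Same helper the generator adds further down in this diff.
  for arg in args:
    if arg is not None:
      return arg
  return None

# Hypothetical transform table, using the "<type>:<index>", ":<index>" and
# "<type>" key forms that _map_sigtypes looks up.
type_transforms = {
    "Tensor:0": "AnyTorchImmutableTensor",
    ":2": "AnyTorchIntType",
    "Tensor": "AnyTorchImmutableTensor",
}

def lookup(torch_type, i):
  return _first_non_none(type_transforms.get(f"{torch_type}:{i}"),
                         type_transforms.get(f":{i}"),
                         type_transforms.get(torch_type))

print(lookup("Tensor", 0))  # exact type:index key wins -> AnyTorchImmutableTensor
print(lookup("int", 2))     # falls back to the index-only key -> AnyTorchIntType
print(lookup("Tensor", 3))  # falls back to the type-only key -> AnyTorchImmutableTensor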
@ -250,16 +382,130 @@ class OpGenerator:
|
||||||
# Generate the type conversion flags.
|
# Generate the type conversion flags.
|
||||||
type_flag_list = []
|
type_flag_list = []
|
||||||
for i, sigitem in enumerate(siglist):
|
for i, sigitem in enumerate(siglist):
|
||||||
torch_type = sigitem["type"]
|
torch_type = (sigitem["type"]
|
||||||
|
if override_types is None else override_types[i])
|
||||||
# Look up the type transform.
|
# Look up the type transform.
|
||||||
flags = (flag_transforms.get(f"{torch_type}:{i}") or
|
if i in drop_indices:
|
||||||
flag_transforms.get(f":{i}") or flag_transforms.get(torch_type))
|
flags = ["kDrop"]
|
||||||
if not flags:
|
else:
|
||||||
flags = []
|
flags = _first_non_none(flag_transforms.get(f"{torch_type}:{i}"),
|
||||||
|
flag_transforms.get(f":{i}"),
|
||||||
|
flag_transforms.get(torch_type))
|
||||||
|
if flags is None:
|
||||||
|
flags = []
|
||||||
type_flag_list.append((torch_type, flags))
|
type_flag_list.append((torch_type, flags))
|
||||||
return ods_dag_list, type_flag_list
|
return ods_dag_list, type_flag_list
|
||||||
|
|
||||||
|
|
||||||
|
class InflightOpDef:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
g: OpGenerator,
|
||||||
|
kernel_sig,
|
||||||
|
ods_name,
|
||||||
|
op_name,
|
||||||
|
traits=(),
|
||||||
|
alias_kernel_names=(),
|
||||||
|
promote_trailing_out_tensor=False,
|
||||||
|
override_arg_types=None,
|
||||||
|
override_return_types=None,
|
||||||
|
drop_arg_indices=(),
|
||||||
|
drop_return_indices=()):
|
||||||
|
super().__init__()
|
||||||
|
self.g = g
|
||||||
|
self.kernel_sig = kernel_sig
|
||||||
|
self.ods_name = ods_name
|
||||||
|
self.op_name = op_name
|
||||||
|
self.traits = list(traits)
|
||||||
|
self.alias_kernel_names = list(alias_kernel_names)
|
||||||
|
self.promote_trailing_out_tensor = promote_trailing_out_tensor
|
||||||
|
self.override_arg_types = override_arg_types
|
||||||
|
self.override_return_types = override_return_types
|
||||||
|
self.drop_arg_indices = drop_arg_indices
|
||||||
|
self.drop_return_indices = drop_return_indices
|
||||||
|
self.reg_record = g.get_reg_record(self.kernel_sig)
|
||||||
|
self._emitted = False
|
||||||
|
self._traceback = traceback.extract_stack()[0:-2]
|
||||||
|
|
||||||
|
# Arg and flag transform dicts.
|
||||||
|
self.arg_type_transforms = dict()
|
||||||
|
self.arg_flag_transforms = dict()
|
||||||
|
self.return_type_transforms = dict()
|
||||||
|
self.return_flag_trasforms = dict()
|
||||||
|
|
||||||
|
# Signature mapping.
|
||||||
|
self._sigs_mapped = False
|
||||||
|
self.ods_ins = None
|
||||||
|
self.ods_outs = None
|
||||||
|
self.arg_type_flags = None
|
||||||
|
self.return_type_flags = None
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if not self._emitted:
|
||||||
|
print("WARNING: Op defined but not emitted. Defined at:", file=sys.stderr)
|
||||||
|
for line in traceback.format_list(self._traceback):
|
||||||
|
sys.stderr.write(line)
|
||||||
|
|
||||||
|
def transforms(self, type_transforms=None, flag_transforms=None):
|
||||||
|
self.arg_transforms(type_transforms=type_transforms,
|
||||||
|
flag_transforms=flag_transforms)
|
||||||
|
self.return_transforms(type_transforms=type_transforms,
|
||||||
|
flag_transforms=flag_transforms)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def arg_transforms(self, type_transforms=None, flag_transforms=None):
|
||||||
|
"""Adds arg type and flag transforms dicts."""
|
||||||
|
if type_transforms:
|
||||||
|
self.arg_type_transforms.update(type_transforms)
|
||||||
|
if flag_transforms:
|
||||||
|
self.arg_flag_transforms.update(flag_transforms)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def return_transforms(self, type_transforms=None, flag_transforms=None):
|
||||||
|
"""Adds return type and flag transform dicts."""
|
||||||
|
if type_transforms:
|
||||||
|
self.return_type_transforms.update(type_transforms)
|
||||||
|
if flag_transforms:
|
||||||
|
self.return_flag_trasforms.update(flag_transforms)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def map_signatures(self):
|
||||||
|
assert not self._sigs_mapped, "Signatures already mapped"
|
||||||
|
self._sigs_mapped = True
|
||||||
|
self.ods_ins, self.arg_type_flags = self.g._map_sigtypes(
|
||||||
|
self.reg_record["arguments"],
|
||||||
|
type_transforms=self.arg_type_transforms,
|
||||||
|
flag_transforms=self.arg_flag_transforms,
|
||||||
|
override_types=self.override_arg_types,
|
||||||
|
drop_indices=self.drop_arg_indices)
|
||||||
|
self.ods_outs, self.return_type_flags = self.g._map_sigtypes(
|
||||||
|
self.reg_record["returns"],
|
||||||
|
type_transforms=self.return_type_transforms,
|
||||||
|
flag_transforms=self.return_flag_trasforms,
|
||||||
|
override_types=self.override_return_types,
|
||||||
|
drop_indices=self.drop_return_indices)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def emit(self):
|
||||||
|
assert not self._emitted, "Op already emitted"
|
||||||
|
self._emitted = True
|
||||||
|
if not self._sigs_mapped:
|
||||||
|
self.map_signatures()
|
||||||
|
self.g.ods_emitter.emit_opdef(self.ods_name,
|
||||||
|
self.op_name,
|
||||||
|
self.reg_record,
|
||||||
|
ods_ins=self.ods_ins,
|
||||||
|
ods_outs=self.ods_outs,
|
||||||
|
traits=self.traits)
|
||||||
|
self.g.impl_emitter.emit_kernel_methods(
|
||||||
|
self.ods_name,
|
||||||
|
self.reg_record,
|
||||||
|
arg_type_flags=self.arg_type_flags,
|
||||||
|
return_type_flags=self.return_type_flags,
|
||||||
|
promote_trailing_out_tensor=self.promote_trailing_out_tensor,
|
||||||
|
alias_kernel_names=self.alias_kernel_names)
|
||||||
|
|
||||||
|
|
||||||
class EmitterBase:
|
class EmitterBase:
|
||||||
_INDENT = " "
|
_INDENT = " "
|
||||||
|
|
||||||
|
@@ -355,7 +601,8 @@ class CCImplEmitter(EmitterBase):
                           reg_record,
                           arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
                           return_type_flags: List[Tuple[str, List[Tuple[str]]]],
-                          promote_trailing_out_tensor=False):
+                          promote_trailing_out_tensor=False,
+                          alias_kernel_names: Sequence[str] = ()):
    # getTorchKernelMetadata() method.
    self.print(
        f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{")
@@ -374,6 +621,9 @@ class CCImplEmitter(EmitterBase):
    with self.indent():
      self.print("Torch::BuildKernelMetadata m;")
      self.print(f"m.kernelName = {self.quote(kernel_name)};")
+      for alias_kernel_name in alias_kernel_names:
+        self.print(
+            f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});")
      if promote_trailing_out_tensor:
        self.print("m.promoteTrailingOutTensor = true;")
      # Arg types/flags.
@@ -393,7 +643,7 @@ class CCImplEmitter(EmitterBase):
      self.print("return m;")
    self.print("})();")
    self.print("return metadata;")
-    self.print("}")
+    self.print("}\n")
 
  def _format_cpp_str_initlist(self, strings):
    quoted = [self.quote(s) for s in strings]
@@ -416,6 +666,13 @@ def snakecase_to_camelcase(ident: str):
  return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
 
 
+def _first_non_none(*args):
+  for arg in args:
+    if arg is not None:
+      return arg
+  return None
+
+
 def _load_ops_as_dict():
  # Returns a list of dicts, each with a name that is a tuple of the form:
  # (kernel_signature, variant)
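The alias_kernel_names parameter threaded through emit_kernel_methods above is what lets one recognized op match both the *_overrideable kernel and its plain aten name. A hedged sketch of the call as the generator ends up making it for the convolution registration earlier in this diff; reg_record, arg_type_flags, and return_type_flags stand in for the values computed from the registry:

# Sketch: the InflightOpDef.emit() path calls the C++ emitter roughly like this.
impl_emitter.emit_kernel_methods(
    "ConvolutionOp",
    reg_record,
    arg_type_flags=arg_type_flags,
    return_type_flags=return_type_flags,
    promote_trailing_out_tensor=True,
    alias_kernel_names=["aten::convolution"])
# For each alias, the emitter now adds a line of the form
#   m.aliasKernelNames.push_back("aten::convolution");
# to the generated getTorchBuildKernelMetadata() body.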
@@ -0,0 +1 @@
+ATenOpRegistrations.txt
@@ -61,55 +61,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
 }
 
-// Our jit library only supports 6 argument convolutions, rather than 9
-// arguments supported by pytorch. This operation allows us to represent this
-// limitation temporarily.
-def aten_ConvolutionOp: aten_Op<"_convolution", [NoSideEffect, StatisticsOpInterface]>,
-                        Results<(outs AnyTensor)> {
-  let arguments = (
-    ins AnyTensor:$input,
-    AnyTensor:$weight,
-    AnyTensor:$bias,
-    AnyType:$stride,
-    AnyType:$padding,
-    AnyType:$dilation
-  );
-
-  let summary = "Convolution operator";
-  let description = [{
-    Convolution operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-    uint64_t getOperandTransferVolume(unsigned int idx, bool read);
-    uint64_t getResultTransferVolume(unsigned int idx, bool read);
-  }];
-}
-
-// Our jit library only supports 6 argument convolutions, rather than 9
-// arguments supported by pytorch. This operation allows us to represent this
-// limitation temporarily.
-def aten_ConvolutionBackwardOp: aten_Op<"_convolution_backward", [NoSideEffect, StatisticsOpInterface]>,
-                                Results<(outs AnyTensor:$dx, AnyTensor:$dw, AnyTensor:$db)> {
-  let arguments = (
-    ins AnyTensor:$grad_output,
-    AnyTensor:$input,
-    AnyTensor:$weight,
-    AnyType:$stride,
-    AnyType:$padding,
-    AnyType:$dilation
-  );
-
-  let summary = "ConvolutionBackward operator";
-  let description = [{
-    ConvolutionBackward operator
-  }];
-  let extraClassDeclaration = [{
-    std::map<std::string, uint64_t> getStatistics();
-  }];
-}
-
-
 def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
                     Results<(outs AnyTensor)> {
   let arguments = (
[Context-only hunks in GeneratedATenOps.cpp.inc: each arithmetic op's getTorchBuildKernelMetadata() definition (AddOp, Atan2Op, DivOp, FloorDivideOp, MulOp, RemainderOp, TrueDivideOp, and the unary ops AbsOp through TanhOp, plus the leading context of TruncOp) picks up one added line apiece, matching the CCImplEmitter change above that now prints "}\n" after each generated method; no other changes are visible in these hunks.]
@ -779,3 +820,184 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
|
||||||
})();
|
})();
|
||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// NN ops
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::convolution_overrideable";
|
||||||
|
m.aliasKernelNames.push_back("aten::convolution");
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &ConvolutionBackwardOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::convolution_backward_overrideable";
|
||||||
|
m.aliasKernelNames.push_back("aten::convolution_backward");
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||||
|
m.addReturnTypes({"Tensor?", "Tensor?", "Tensor?"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata LogSoftmaxOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &LogSoftmaxOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::_log_softmax";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "int", "bool"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata LogSoftmaxBackwardDataOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::_log_softmax_backward_data";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "int", "Tensor"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kImmutableTensor});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Loss function ops
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Torch::KernelMetadata NllLossForwardOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &NllLossForwardOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::nll_loss_forward";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
|
||||||
|
m.addReturnTypes({"Tensor", "Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata NllLossBackwardOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &NllLossBackwardOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::nll_loss_backward";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata NllLoss2dForwardOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &NllLoss2dForwardOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::nll_loss2d_forward";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
|
||||||
|
m.addReturnTypes({"Tensor", "Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata NllLoss2dBackwardOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::nll_loss2d_backward";
|
||||||
|
m.promoteTrailingOutTensor = true;
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
|
||||||
|
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kImmutableTensor});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
|
||||||
|
return getTorchBuildKernelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Torch::BuildKernelMetadata &CopyInplaceOp::getTorchBuildKernelMetadata() {
|
||||||
|
using KVC = Torch::KernelValueConversion::BitMask;
|
||||||
|
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||||
|
Torch::BuildKernelMetadata m;
|
||||||
|
m.kernelName = "aten::copy_";
|
||||||
|
m.addArgTypes({"Tensor", "Tensor", "bool"});
|
||||||
|
m.addArgConversions({KVC::kNone, KVC::kImmutableTensor, KVC::kDrop});
|
||||||
|
m.addReturnTypes({"Tensor"});
|
||||||
|
m.addReturnConversions({KVC::kDropReturnAndAliasArg0});
|
||||||
|
return m;
|
||||||
|
})();
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@@ -450,3 +450,147 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods<Torc
  );
}

// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------

def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::convolution_overrideable";
  let arguments = (ins
    AnyTorchImmutableTensor:$input,
    AnyTorchImmutableTensor:$weight,
    AnyTorchOptionalImmutableTensor:$bias,
    AnyTorchIntListType:$stride,
    AnyTorchIntListType:$padding,
    AnyTorchIntListType:$dilation,
    AnyTorchBoolType:$transposed,
    AnyTorchIntListType:$output_padding,
    AnyTorchIntType:$groups
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
  let arguments = (ins
    AnyTorchImmutableTensor:$grad_output,
    AnyTorchImmutableTensor:$input,
    AnyTorchImmutableTensor:$weight,
    AnyTorchIntListType:$stride,
    AnyTorchIntListType:$padding,
    AnyTorchIntListType:$dilation,
    AnyTorchBoolType:$transposed,
    AnyTorchIntListType:$output_padding,
    AnyTorchIntType:$groups,
    AnyTorchBoolListType:$output_mask
  );
  let results = (outs
    AnyTorchOptionalImmutableTensor:$grad_input,
    AnyTorchOptionalImmutableTensor:$grad_weight,
    AnyTorchOptionalImmutableTensor:$grad_bias
  );
}

def aten_LogSoftmaxOp: aten_Op<"log_softmax", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::_log_softmax";
  let arguments = (ins
    AnyTorchImmutableTensor:$self,
    AnyTorchIntType:$dim,
    AnyTorchBoolType:$half_to_float
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::_log_softmax_backward_data";
  let arguments = (ins
    AnyTorchImmutableTensor:$grad_output,
    AnyTorchImmutableTensor:$output,
    AnyTorchIntType:$dim,
    AnyTorchImmutableTensor:$self
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------

def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::nll_loss_forward";
  let arguments = (ins
    AnyTorchImmutableTensor:$self,
    AnyTorchImmutableTensor:$target,
    AnyTorchOptionalImmutableTensor:$weight,
    AnyTorchIntType:$reduction,
    AnyTorchIntType:$ignore_index
  );
  let results = (outs
    AnyTorchImmutableTensor:$output,
    AnyTorchImmutableTensor:$total_weight
  );
}

def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::nll_loss_backward";
  let arguments = (ins
    AnyTorchImmutableTensor:$grad_output,
    AnyTorchImmutableTensor:$self,
    AnyTorchImmutableTensor:$target,
    AnyTorchOptionalImmutableTensor:$weight,
    AnyTorchIntType:$reduction,
    AnyTorchIntType:$ignore_index,
    AnyTorchImmutableTensor:$total_weight
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::nll_loss2d_forward";
  let arguments = (ins
    AnyTorchImmutableTensor:$self,
    AnyTorchImmutableTensor:$target,
    AnyTorchOptionalImmutableTensor:$weight,
    AnyTorchIntType:$reduction,
    AnyTorchIntType:$ignore_index
  );
  let results = (outs
    AnyTorchImmutableTensor:$output,
    AnyTorchImmutableTensor:$total_weight
  );
}

def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::nll_loss2d_backward";
  let arguments = (ins
    AnyTorchImmutableTensor:$grad_output,
    AnyTorchImmutableTensor:$self,
    AnyTorchImmutableTensor:$target,
    AnyTorchOptionalImmutableTensor:$weight,
    AnyTorchIntType:$reduction,
    AnyTorchIntType:$ignore_index,
    AnyTorchImmutableTensor:$total_weight
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::copy_";
  let arguments = (ins
    AnyTorchMutableTensor:$self,
    AnyTorchImmutableTensor:$src
  );
  let results = (outs
  );
}
@@ -44,52 +44,6 @@ def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface
  }];
}

def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$input,
        AnyTensor:$weight,
        AnyTensor:$bias,
        AnyType:$stride,
        AnyType:$padding,
        AnyType:$dilation,
        AnyScalar:$transposed,
        AnyType:$output_padding,
        AnyScalar:$groups
  );
  let summary = "aten convolution_overrideable operator";
  let description = [{
    ConvolutionOverrideableOp
    aten convolution_overrideable operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor, AnyTensor, AnyTensor)> {
  let arguments = (
    ins AnyTensor:$grad_output,
        AnyTensor:$input,
        AnyTensor:$weight,
        AnyType:$stride,
        AnyType:$padding,
        AnyType:$dilation,
        AnyScalar:$transposed,
        AnyType:$output_padding,
        AnyScalar:$groups
  );
  let summary = "aten convolution_backward_overrideable operator";
  let description = [{
    ConvolutionBackwardOverrideableOp
    aten convolution_backward_overrideable operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
@@ -123,35 +77,6 @@ def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>,
  }];
}

def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$self,
        AnyScalar:$dim,
        AnyScalar:$half_to_float
  );
  let summary = "aten _log_softmax operator";
  let description = [{
    LogSoftmaxOp
    aten _log_softmax operator
  }];
}

def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$grad_output,
        AnyTensor:$output,
        AnyScalar:$dim,
        AnyTensor:$self
  );
  let summary = "aten _log_softmax_backward_data operator";
  let description = [{
    LogSoftmaxBackwardDataOp
    aten _log_softmax_backward_data operator
  }];
}

def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
@@ -445,86 +370,6 @@ def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>,
  }];
}

def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor, AnyTensor)> {
  let arguments = (
    ins AnyTensor:$self,
        AnyTensor:$target,
        AnyTensor:$weight,
        AnyScalar:$reduction,
        AnyScalar:$ignore_index
  );
  let summary = "aten nll_loss_forward operator";
  let description = [{
    NllLossForwardOp
    aten nll_loss_forward operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$grad_output,
        AnyTensor:$self,
        AnyTensor:$target,
        AnyTensor:$weight,
        AnyScalar:$reduction,
        AnyScalar:$ignore_index,
        AnyTensor:$total_weight
  );
  let summary = "aten nll_loss_backward operator";
  let description = [{
    NllLossBackwardOp
    aten nll_loss_backward operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor, AnyTensor)> {
  let arguments = (
    ins AnyTensor:$self,
        AnyTensor:$target,
        AnyTensor:$weight,
        AnyScalar:$reduction,
        AnyScalar:$ignore_index
  );
  let summary = "aten nll_loss2d_forward operator";
  let description = [{
    NllLoss2dForwardOp
    aten nll_loss2d_forward operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$grad_output,
        AnyTensor:$self,
        AnyTensor:$target,
        AnyTensor:$weight,
        AnyScalar:$reduction,
        AnyScalar:$ignore_index,
        AnyTensor:$total_weight
  );
  let summary = "aten nll_loss2d_backward operator";
  let description = [{
    NllLoss2dBackwardOp
    aten nll_loss2d_backward operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
@@ -12,15 +12,4 @@
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "npcomp/Dialect/ATen/IR/ATenOps.td"

// The pytorch convolution operator has 9 arguments, but we only have a jit
// library that supports the first six at the moment.
def : Pat<(aten_ConvolutionOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
                                          $a7, $a8, $a9),
          (aten_ConvolutionOp $a1, $a2, $a3, $a4, $a5, $a6)>;

def : Pat<(aten_ConvolutionBackwardOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
                                                  $a7, $a8, $a9),
          (aten_ConvolutionBackwardOp $a1, $a2, $a3, $a4, $a5, $a6)>;

#endif
@@ -18,7 +18,7 @@ namespace Torch {
/// Conversion rule to apply to a value (argument or return).
namespace KernelValueConversion {
enum BitMask {
enum BitMask : uint32_t {
  // No coercion.
  kNone = 0,
@@ -32,7 +32,16 @@ enum BitMask {
  // to a 0d tensor.
  kPromoteScalar = 8,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar)
  // Drops the return value and aliases to argument 0.
  // TODO: Remove this in favor of general alias metadata processing (note that
  // the vast majority are this one case so it isn't so bad to have a special
  // case for it if necessary).
  kDropReturnAndAliasArg0 = 16,

  // Drops the argument/return.
  kDrop = 32,

  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kDrop)
};
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
} // namespace KernelValueConversion
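Because BitMask is declared as an LLVM bitmask enum, the conversion flags compose with ordinary bitwise operators. A minimal sketch of how a consumer tests them (the variable names here are illustrative, not taken from the pass):

// Sketch: combining and querying KernelValueConversion flags.
using KVC = KernelValueConversion::BitMask;
KVC flag = KVC::kImmutableTensor | KVC::kPromoteScalar;  // request both coercions
if (flag & KVC::kImmutableTensor) {
  // Coerce the mutable NdArray value to an immutable tensor value.
}
if (flag & KVC::kDropReturnAndAliasArg0) {
  // Drop this return and forward its uses to operand 0 instead.
}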
@@ -74,6 +83,9 @@ struct BuildKernelMetadata : public KernelMetadata {
  SmallVector<KernelValueConversion::BitMask, 4> argConversions;
  SmallVector<KernelValueConversion::BitMask, 4> returnConversions;

  /// Additional alias kernel names to match.
  SmallVector<StringRef, 1> aliasKernelNames;

  void addArgConversions(
      std::initializer_list<KernelValueConversion::BitMask> ilist) {
    argConversions.insert(argConversions.end(), ilist);
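The new aliasKernelNames field lets a single recognized op match several registered kernel names, for example the device-generic aten::convolution_overrideable plus its aten::convolution alias. A plausible way the convolution metadata is populated, sketched here with assumed conversion flags rather than the generated code:

// Hypothetical BuildKernelMetadata population for the convolution op.
using KVC = KernelValueConversion::BitMask;
Torch::BuildKernelMetadata m;
m.kernelName = "aten::convolution_overrideable";
m.aliasKernelNames.push_back("aten::convolution");
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                     KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone,
                     KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnConversions({KVC::kImmutableTensor});

The registration loop added later in this diff then files the same candidate transform under each alias name.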
@@ -72,6 +72,11 @@ def AnyTorchImmutableTensor : AnyTypeOf<[
  AnyTensor,
], "allowable torch immutable tensor">;

def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
  AnyTorchImmutableTensor,
  Basicpy_NoneType,
], "allowable torch immutable tensor (or None)">;

def AnyTorchMutableTensor : AnyTypeOf<[
  // "Numpy-style" mutable NDArray. While not offering the full generality
  // of a Torch tensor, it models the same access patterns and implies the
@@ -95,7 +100,28 @@ def AnyTorchScalarType : AnyTypeOf<[
  AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;

def AnyTorchBoolType : AnyTypeOf<[
  I1,
  Basicpy_BoolType,
], "Any permissible bool type">;

def AnyTorchBoolListType : AnyTypeOf<[
  Basicpy_ListType,
  // TODO: Support typed list when available.
], "Any bool list type (bool[])">;

def AnyTorchIntType : AnyTypeOf<[
  AnySignedInteger,
  AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;

def AnyTorchIntListType : AnyTypeOf<[
  Basicpy_ListType,
  // TODO: Support typed list when available.
], "Any int list type (int[])">;

def AnyTorchType : AnyTypeOf<[
  AnyTorchBoolType,
  AnyTorchScalarType,
  AnyTorchTensorType,
  Basicpy_ListType,
@@ -170,40 +170,6 @@ std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
  return toReturn;
}

// _convolution
std::map<std::string, uint64_t> ConvolutionOp::getStatistics() {
  return getConv2dStatistics(this, /*groups*/ 1);
}
std::map<std::string, uint64_t> ConvolutionOverrideableOp::getStatistics() {
  // FIXME
  auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
  auto ia = co.template getAttrOfType<IntegerAttr>("value");
  uint64_t groups = ia.getValue().getZExtValue();

  return getConv2dStatistics(this, groups);
}

uint64_t ConvolutionOp::getOperandTransferVolume(unsigned int idx, bool read) {
  return getConv2dOperandTransferVolume<ConvolutionOp>(this, idx, read);
}

uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
  return getConv2dResultTransferVolume<ConvolutionOp>(this, idx, write);
}

// _convolution_backward
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
  return getConv2dBackwardStatistics(*this, 1);
}
std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
  auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
  auto ia = co.template getAttrOfType<IntegerAttr>("value");
  uint64_t groups = ia.getValue().getZExtValue();

  return getConv2dBackwardStatistics(*this, groups);
}

// div_
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {

@@ -559,35 +525,6 @@ std::map<std::string, uint64_t> NativeBatchNormBackwardOp::getStatistics() {
  return toReturn;
}

std::map<std::string, uint64_t> NllLossForwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}
std::map<std::string, uint64_t> NllLossBackwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dForwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dBackwardOp::getStatistics() {
  std::map<std::string, uint64_t> toReturn;
  // FIXME: unimplemented
  toReturn["reads"] = -1;
  toReturn["writes"] = -1;
  return toReturn;
}

// std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() {
//   return getReLUOpStatistics(*this);
// }
@@ -12,6 +12,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"

@@ -28,6 +29,14 @@ using namespace mlir::NPCOMP::Torch;

namespace {

bool isTorchTensorType(StringRef torchType) {
  return torchType == "Tensor" || torchType == "Tensor?";
}

bool isTorchOptionalType(StringRef torchType) {
  return torchType.endswith("?");
}

struct TypeConversion {
  Type targetType;
  std::function<Value(Location loc, Value originalValue,

@@ -49,9 +58,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
  // Immutable tensor conversion.
  if (flag & KVC::kImmutableTensor) {
    // TODO: Support the kPromoteScalar flag.
    if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType))
      return None;

    // If the target is optional and the type is NoneType, passthrough.
    if (isTorchOptionalType(targetTorchType) &&
        sourceMlirType.isa<Basicpy::NoneType>())
      return TypeConversion{sourceMlirType, nullptr};

    // Already immutable.
    if (sourceMlirType.isa<TensorType>())
      return TypeConversion{sourceMlirType, nullptr};

@@ -86,30 +101,51 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
                       Type sourceMlirType) {
  using KVC = KernelValueConversion::BitMask;
  // Default trivial case.
  if (sourceTorchType == targetTorchType && flag == 0)
  if (sourceTorchType == targetTorchType && flag == 0) {
    LLVM_DEBUG(llvm::dbgs() << " * Return types already match\n");
    return TypeConversion{sourceMlirType, nullptr};
  }

  // Immutable tensor conversion.
  if (flag & KVC::kImmutableTensor) {
    if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
    LLVM_DEBUG(llvm::dbgs()
               << " * Return conversion flag kImmutableTensor\n");
    if (!isTorchTensorType(sourceTorchType) ||
        !isTorchTensorType(targetTorchType)) {
      LLVM_DEBUG(llvm::dbgs()
                 << " * Source or target not a Tensor type\n");
      return None;
    }

    // Already immutable.
    if (sourceMlirType.isa<TensorType>())
    if (sourceMlirType.isa<TensorType>()) {
      LLVM_DEBUG(llvm::dbgs() << " * Source is already immutable\n");
      return TypeConversion{sourceMlirType, nullptr};
    }

    // Convert NdArray type.
    if (auto ndArrayType = sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
    if (sourceMlirType.isa<Basicpy::NoneType>() &&
        isTorchOptionalType(targetTorchType)) {
      LLVM_DEBUG(llvm::dbgs() << " * None Tensor type passthrough\n");
      return TypeConversion{sourceMlirType, nullptr};
    } else if (auto ndArrayType =
                   sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
      auto tensorType = ndArrayType.toTensorType();
      auto callback = [=](Location loc, Value newOpResultValue,
                          PatternRewriter &rewriter) -> Value {
        return rewriter.create<Numpy::CreateArrayFromTensorOp>(
            loc, ndArrayType, newOpResultValue);
      };
      LLVM_DEBUG(llvm::dbgs() << " * Convert return type\n");
      return TypeConversion{tensorType, callback};
    } else {
      LLVM_DEBUG(llvm::dbgs()
                 << " * Return type is not a supported tensor type\n");
      return None;
    }
  }

  LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
  return None;
}

@@ -142,9 +178,16 @@ public:
                 const BuildKernelMetadata &buildMetadata) {
    LLVM_DEBUG(llvm::dbgs()
               << "Register kernel call translation for: " << opName << "\n");
    CandidateTransformList &candidates =
        kernelTransforms[buildMetadata.kernelName];
    candidates.emplace_back(opName, buildMetadata);
    {
      CandidateTransformList &candidates =
          kernelTransforms[buildMetadata.kernelName];
      candidates.emplace_back(opName, buildMetadata);
    }
    for (StringRef aliasKernelName : buildMetadata.aliasKernelNames) {
      CandidateTransformList &candidates = kernelTransforms[aliasKernelName];
      candidates.emplace_back(opName, buildMetadata);
    }
  }

  LogicalResult transformKernelCall(KernelCallOp kernelCall,

@@ -229,6 +272,8 @@ public:
                   "arg arity mismatch");

    // Convert fixed return types.
    using PostConversionCallback = std::function<void()>;
    SmallVector<PostConversionCallback, 4> postConversionCallbacks;
    struct ConversionInfo {
      Value originalValue;
      TypeConversion conversion;

@@ -241,25 +286,49 @@ public:
      KVC flag = candidate.buildMetadata.getReturnConversion(i);
      Value sourceValue = kernelCall.getResult(i);
      Type sourceMlirType = kernelCall.getResultTypes()[i];
      auto conversion = convertTorchReturnType(sourceTorchType, targetTorchType,
                                               flag, sourceMlirType);
      if (!conversion) {
        LLVM_DEBUG(llvm::dbgs() << " - Return type[" << i
                   << "] incompatible: source=" << sourceTorchType
                   << ", target=" << targetTorchType
                   << ", flag=" << flag << "\n");
        return failure();
      if (flag & KVC::kDropReturnAndAliasArg0) {
        // Reduce result arity and alias any uses to arg0.
        if (kernelCall.args().empty()) {
          LLVM_DEBUG(llvm::dbgs()
                     << " - Cannot alias arg0 (no arguments)\n");
          return failure();
        }
        Value arg0 = kernelCall.args()[0];
        postConversionCallbacks.push_back(
            [sourceValue, arg0]() { sourceValue.replaceAllUsesWith(arg0); });
      } else {
        // General, arity-preserving type conversion.
        auto conversion = convertTorchReturnType(
            sourceTorchType, targetTorchType, flag, sourceMlirType);
        if (!conversion) {
          LLVM_DEBUG(llvm::dbgs()
                     << " - Return type[" << i << "] incompatible: source="
                     << sourceTorchType << ", target=" << targetTorchType
                     << ", flag=" << flag << "\n");
          return failure();
        }
        resultTypes.push_back(conversion->targetType);
        resultConversions.push_back({sourceValue, std::move(*conversion)});
      }
      resultTypes.push_back(conversion->targetType);
      resultConversions.push_back({sourceValue, std::move(*conversion)});
    }

    // Convert fixed arg types.
    SmallVector<ConversionInfo, 4> operandInfos;
    for (size_t i = 0; i < fixedArgArity; ++i) {
    for (size_t i = 0, operandIndex = 0; i < fixedArgArity; ++i) {
      // Drop this arg?
      if (candidate.buildMetadata.argConversions[i] & KVC::kDrop)
        continue;
      if (kernelCall.getNumOperands() <= operandIndex) {
        LLVM_DEBUG(llvm::dbgs()
                   << " - Arg operand " << i
                   << " does not exist in kernel call (missing default?)\n");
        return failure();
      }

      // Normal type conversion of the operand.
      operandInfos.emplace_back();
      ConversionInfo &info = operandInfos.back();
      info.originalValue = kernelCall.getOperand(i);
      info.originalValue = kernelCall.getOperand(operandIndex++);
      Type sourceMlirType = info.originalValue.getType();
      auto conversion = convertTorchArgType(
          /*sourceTorchType=*/sourceMetadata.argTypes[i],

@@ -312,6 +381,10 @@ public:
      origOpResultValue.replaceAllUsesWith(convertedValue);
    }

    // Post conversion callbacks.
    for (auto &callback : postConversionCallbacks)
      callback();

    // Done.
    rewriter.eraseOp(kernelCall);
    return success();
@@ -1,25 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-_convolution-0": {
// CHECK-NEXT: "activation_in": 32768,
// CHECK-NEXT: "activation_out": 65536,
// CHECK-NEXT: "ops:+": 65536,
// CHECK-NEXT: "ops:MAC": 6422528,
// CHECK-NEXT: "parameters_in": 1584,
// CHECK-NEXT: "reads": 34352,
// CHECK-NEXT: "writes": 65536

module {
  func @graph(%arg0: tensor<1x2x128x128xf32>, %arg1: tensor<16x2x7x7xf32>, %arg2: tensor<16xf32>) -> tensor<1x16x64x64xf32> {
    %0 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list<i32>
    %1 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list<i32>
    %2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
    %3 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
    %4 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi64>} : () -> !aten.list<i32>
    %5 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
    %6 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
    %7 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
    %8 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
    %9 = "aten._convolution"(%arg0, %arg1, %arg2, %0, %1, %2) : (tensor<1x2x128x128xf32>, tensor<16x2x7x7xf32>, tensor<16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x16x64x64xf32>
    "std.return"(%9) : (tensor<1x16x64x64xf32>) -> ()
  }
}
@@ -1,23 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-convolution_backward_overrideable-0": {
// CHECK-NEXT: "activation_in": 5568,
// CHECK-NEXT: "grad": 5380,
// CHECK-NEXT: "ops:+": 768,
// CHECK-NEXT: "ops:MAC": 345600,
// CHECK-NEXT: "parameters_in": 576,
// CHECK-NEXT: "reads": 6144,
// CHECK-NEXT: "writes": 5380
// CHECK-NEXT: }

// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
module {
  func @graph(%arg0: tensor<3x4x8x8xf32>, %arg1: tensor<3x16x10x10xf32>, %arg2: tensor<4x16x3x3xf32>) -> tensor<4x16x3x3xf32> {
    %0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
    %1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
    %2 = "aten.constant"() {type = "bool", value = false} : () -> i1
    %3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
    %10:3 = "aten.convolution_backward_overrideable"(%arg0, %arg1, %arg2, %0, %1, %0, %2, %1, %3) {layer_name = "L5-convolution_backward_overrideable-0"} : (tensor<3x4x8x8xf32>, tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1, !aten.list<i32>, i32) -> (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>)
    return %10#1 : tensor<4x16x3x3xf32>
  }
}
@@ -1,26 +0,0 @@
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph

module {
  func @graph(%arg0: tensor<10xf32>, %arg1: tensor<128xf32>, %arg2: tensor<4x1x28x28xf32>, %arg3: tensor<32x1x3x3xf32>, %arg4: tensor<32xf32>, %arg5: tensor<64x32x3x3xf32>, %arg6: tensor<64xf32>, %arg7: tensor<128x9216xf32>, %arg8: tensor<10x128xf32>) -> tensor<4x10xf32> {
    %0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
    %1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
    %2 = "aten.constant"() {type = "bool", value = false} : () -> i1
    %3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
    %4 = "aten.constant"() {type = "bool", value = true} : () -> i1
    %5 = "aten._convolution"(%arg2, %arg3, %arg4, %0, %1, %0) {layer_name = "L0-_convolution-0"} : (tensor<4x1x28x28xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x32x26x26xf32>
    %6 = "aten.relu"(%5) {layer_name = "L1-relu-0"} : (tensor<4x32x26x26xf32>) -> tensor<4x32x26x26xf32>
    %7 = "aten._convolution"(%6, %arg5, %arg6, %0, %1, %0) {layer_name = "L2-_convolution-1"} : (tensor<4x32x26x26xf32>, tensor<64x32x3x3xf32>, tensor<64xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x64x24x24xf32>
    %8 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi32>} : () -> !aten.list<i32>
    %9:2 = "aten.max_pool2d_with_indices"(%7, %8, %8, %1, %0, %2) {layer_name = "L3-max_pool2d_with_indices-0"} : (tensor<4x64x24x24xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1) -> (tensor<4x64x12x12xf32>, tensor<4x64x12x12xi64>)
    %10 = "aten.constant"() {type = "List[i32]", value = dense<[4, 9216]> : vector<2xi32>} : () -> !aten.list<i32>
    %11 = "aten.view"(%9#0, %10) {layer_name = "L4-view-0"} : (tensor<4x64x12x12xf32>, !aten.list<i32>) -> tensor<4x9216xf32>
    %12 = "aten.t"(%arg7) {layer_name = "L5-t-0"} : (tensor<128x9216xf32>) -> tensor<9216x128xf32>
    %13 = "aten.addmm"(%arg1, %11, %12, %3, %3) {layer_name = "L6-addmm-0"} : (tensor<128xf32>, tensor<4x9216xf32>, tensor<9216x128xf32>, i32, i32) -> tensor<4x128xf32>
    %14 = "aten.relu"(%13) {layer_name = "L7-relu-1"} : (tensor<4x128xf32>) -> tensor<4x128xf32>
    %15 = "aten.t"(%arg8) {layer_name = "L8-t-1"} : (tensor<10x128xf32>) -> tensor<128x10xf32>
    %16 = "aten.addmm"(%arg0, %14, %15, %3, %3) {layer_name = "L9-addmm-1"} : (tensor<10xf32>, tensor<4x128xf32>, tensor<128x10xf32>, i32, i32) -> tensor<4x10xf32>
    %17 = "aten._log_softmax"(%16, %3, %2) {layer_name = "L10-_log_softmax-0"} : (tensor<4x10xf32>, i32, i1) -> tensor<4x10xf32>
    return %17 : tensor<4x10xf32>
  }
}
@@ -13,3 +13,86 @@ func @graph(%arg0: !numpy.ndarray<*:?>, %arg1 : !numpy.ndarray<*:?>, %arg2 : si6
  // CHECK: return %[[RESULT_MUT]]
  return %0 : !numpy.ndarray<*:?>
}

// -----
// CHECK-LABEL: func @nll_loss2d_forward
// Contains a Tensor? type mapped to None.
func @nll_loss2d_forward(
    %arg0: !numpy.ndarray<[3,4,8,8]:f32>,
    %arg1: !numpy.ndarray<[3,8,8]:i64>,
    %arg2: !basicpy.NoneType,
    %arg3: i64,
    %arg4: i64) -> (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>) {
  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: %[[TOUTPUT:.*]], %[[TTOTAL_WEIGHT:.*]] = "aten.nll_loss2d_forward"(%[[TARG0]], %[[TARG1]], %arg2, %arg3, %arg4) : (tensor<3x4x8x8xf32>, tensor<3x8x8xi64>, !basicpy.NoneType, i64, i64) -> (tensor<f32>, tensor<f32>)
  // CHECK: %[[AOUTPUT:.*]] = numpy.create_array_from_tensor %[[TOUTPUT]]
  // CHECK: %[[ATOTAL_WEIGHT:.*]] = numpy.create_array_from_tensor %[[TTOTAL_WEIGHT]]
  %0:2 = torch.kernel_call "aten::nll_loss2d_forward"
      %arg0, %arg1, %arg2, %arg3, %arg4 :
      (!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,8,8]:i64>, !basicpy.NoneType, i64, i64) ->
      (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>)
      {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor"]}
  // CHECK: return %[[AOUTPUT]], %[[ATOTAL_WEIGHT]]
  return %0#0, %0#1 : !numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>
}

// -----
// CHECK-LABEL: func @convolution
// Contains a Tensor?, bool, int and list types.
func @convolution(
    %arg0: !numpy.ndarray<[3,16,10,10]:f32>, %arg1: !numpy.ndarray<[4,16,3,3]:f32>,
    %arg2: !numpy.ndarray<[4]:f32>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType,
    %arg5: !basicpy.ListType, %arg6: i1, %arg7: !basicpy.ListType, %arg8: i64) -> !numpy.ndarray<[3,4,8,8]:f32> {
  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
  // CHECK: %[[TRESULT:.*]] = "aten.convolution"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> tensor<3x4x8x8xf32>
  %0 = torch.kernel_call "aten::convolution"
      %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8 :
      (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>,
       !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType,
       !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
      {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  return %0 : !numpy.ndarray<[3,4,8,8]:f32>
}

// -----
// CHECK-LABEL: func @convolution_backward
// Interesting because it has optional tensor returns.
func @convolution_backward(
    %arg0: !numpy.ndarray<[3,4,8,8]:f32>,
    %arg1: !numpy.ndarray<[3,16,10,10]:f32>,
    %arg2: !numpy.ndarray<[4,16,3,3]:f32>,
    %arg3: !basicpy.ListType,
    %arg4: !basicpy.ListType,
    %arg5: !basicpy.ListType,
    %arg6: i1,
    %arg7: !basicpy.ListType,
    %arg8: i64,
    %arg9: !basicpy.ListType) -> (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {
  // CHECK: %[[GRAD_INPUT:.*]], %[[GRAD_WEIGHT:.*]], %[[GRAD_BIAS:.*]] = "aten.convolution_backward"
  // Note that this kernel call masks out the input gradients, which will return as NoneType.
  %0:3 = torch.kernel_call "aten::convolution_backward"
      %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 :
      (!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64, !basicpy.ListType) ->
      (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {sigArgTypes = ["Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor", "Tensor"]}
  // CHECK: %[[AGRAD_WEIGHT:.*]] = numpy.create_array_from_tensor %[[GRAD_WEIGHT]]
  // CHECK: %[[AGRAD_BIAS:.*]] = numpy.create_array_from_tensor %[[GRAD_BIAS]]
  // Key thing: The return returns the raw NoneType from the masked input gradient
  // and it does not get converted to an array.
  // CHECK: return %[[GRAD_INPUT]], %[[AGRAD_WEIGHT]], %[[AGRAD_BIAS]]
  return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
}

// -----
// CHECK-LABEL: func @copy_inplace
// Mutable/in-place op conversion, dropping result.
func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: "aten.copy.inplace"(%arg0, %[[TARG1]]) : (!numpy.ndarray<[4]:f32>, tensor<4xf32>) -> ()
  %0 = torch.kernel_call "aten::copy_" %arg0, %arg1 : (!numpy.ndarray<[4]:f32>, !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {sigArgTypes = ["Tensor", "Tensor", "bool"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  // CHECK: return %arg0
  return %0 : !numpy.ndarray<[4]:f32>
}