Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,20 +484,32 @@ def c_extract(self, name, sub, check_input=True, **kwargs):
def c_sync(self, name, sub):
specs = self.dtype_specs()
fail = sub["fail"]
dtype = specs[1]
cls = specs[2]
(np_dtype, _c_dtype, _cls_name) = specs
np_dtype_num = np.dtype(np_dtype).num

return f"""
Py_XDECREF(py_{name});
py_{name} = PyArrayScalar_New({cls});

PyArray_Descr* {name}_descr = PyArray_DescrFromType({np_dtype_num}); // {np_dtype}
if (!{name}_descr) {{
PyErr_Format(PyExc_RuntimeError, "Could not get descriptor for {np_dtype_num}={np_dtype}");
{fail}
}}

// PyArray_Scalar creates a new scalar object by copying data from the pointer &{name}
py_{name} = PyArray_Scalar(&{name}, {name}_descr, NULL);

// Clean up the descriptor reference (PyArray_DescrFromType returns a new ref)
Py_DECREF({name}_descr);

if (!py_{name})
{{
Py_XINCREF(Py_None);
py_{name} = Py_None;
PyErr_Format(PyExc_MemoryError,
"Instantiation of new Python scalar failed ({dtype})");
"Instantiation of new Python NumPy scalar failed ({np_dtype_num}={np_dtype})");
{fail}
}}
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
"""

def c_cleanup(self, name, sub):
Expand Down Expand Up @@ -762,7 +774,7 @@ def c_init_code(self, **kwargs):
return ["import_array();"]

def c_code_cache_version(self):
return (14, np.__version__)
return (15, np.__version__)

def get_shape_info(self, obj):
return obj.itemsize
Expand Down
3 changes: 1 addition & 2 deletions pytensor/tensor/signal/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
def perform(self, node, inputs, outputs):
in1, in2, full_mode = inputs

# TODO: Why is .item() needed?
mode: Literal["full", "valid", "same"] = "full" if full_mode.item() else "valid"
mode = "full" if full_mode else "valid"
outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)


Expand Down
7 changes: 7 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2209,6 +2209,13 @@ def test_ScalarFromTensor(cast_policy):
scalar_from_tensor(vector())


def test_bool_scalar_from_tensor():
x = scalar("x", dtype="bool")
fn = function([x], scalar_from_tensor(x))
assert fn(np.array(True, dtype=bool))
assert not fn(np.array(False, dtype=bool))


def test_op_cache():
# TODO: What is this actually testing?
# trigger bug in ticket #162
Expand Down
Loading