diff --git a/src/class/mtp/mtp_device.c b/src/class/mtp/mtp_device.c index 13f868d14..59096e476 100644 --- a/src/class/mtp/mtp_device.c +++ b/src/class/mtp/mtp_device.c @@ -92,6 +92,7 @@ typedef struct { uint8_t itf_num; uint8_t ep_in; uint8_t ep_out; + uint8_t ep_event; uint8_t ep_sz_fs; // Bulk Only Transfer (BOT) Protocol @@ -194,50 +195,47 @@ static bool prepare_new_command(mtpd_interface_t* p_mtp) { return usbd_edpt_xfer(p_mtp->rhport, p_mtp->ep_out, _mtpd_epbuf.buf, CFG_TUD_MTP_EP_BUFSIZE, false); } -static bool mtpd_data_xfer(mtp_container_info_t* p_container, uint8_t ep_addr) { - mtpd_interface_t* p_mtp = &_mtpd_itf; +bool tud_mtp_data_send(mtp_container_info_t *p_container) { + mtpd_interface_t *p_mtp = &_mtpd_itf; if (p_mtp->phase == MTP_PHASE_COMMAND) { // 1st data block: header + payload p_mtp->phase = MTP_PHASE_DATA; p_mtp->xferred_len = 0; + p_mtp->total_len = p_container->header->len; - if (tu_edpt_dir(ep_addr) == TUSB_DIR_IN) { - p_mtp->total_len = p_container->header->len; - p_container->header->type = MTP_CONTAINER_TYPE_DATA_BLOCK; - p_container->header->transaction_id = p_mtp->command.header.transaction_id; - p_mtp->io_header = *p_container->header; // save header for subsequent data - } else { - p_mtp->total_len = p_container->header->len; - } - } else { - // subsequent data block: payload only - TU_ASSERT(p_mtp->phase == MTP_PHASE_DATA); + p_container->header->type = MTP_CONTAINER_TYPE_DATA_BLOCK; + p_container->header->transaction_id = p_mtp->command.header.transaction_id; + p_mtp->io_header = *p_container->header; // save header for subsequent data } - uint16_t xact_len = 0; - if (tu_edpt_dir(ep_addr) == TUSB_DIR_IN) { - xact_len = (uint16_t) tu_min32(p_mtp->total_len - p_mtp->xferred_len, CFG_TUD_MTP_EP_BUFSIZE); - } else { - // Use fixed transfer length to make ZLP handling easier - xact_len = CFG_TUD_MTP_EP_BUFSIZE; - } + const uint16_t xact_len = (uint16_t)tu_min32(p_mtp->total_len - p_mtp->xferred_len, CFG_TUD_MTP_EP_BUFSIZE); - TU_LOG_DRV(" MTP Data Xfer %s: xferred_len/total_len=%lu/%lu, xact_len=%u\r\n", - (tu_edpt_dir(ep_addr) == TUSB_DIR_IN) ? "IN" : "OUT", - p_mtp->xferred_len, p_mtp->total_len, xact_len); + TU_LOG_DRV(" MTP Data IN: xferred_len/total_len=%lu/%lu, xact_len=%u\r\n", p_mtp->xferred_len, p_mtp->total_len, + xact_len); if (xact_len) { - TU_VERIFY(usbd_edpt_claim(p_mtp->rhport, ep_addr)); - TU_ASSERT(usbd_edpt_xfer(p_mtp->rhport, ep_addr, _mtpd_epbuf.buf, xact_len, false)); + TU_VERIFY(usbd_edpt_claim(p_mtp->rhport, p_mtp->ep_in)); + TU_ASSERT(usbd_edpt_xfer(p_mtp->rhport, p_mtp->ep_in, _mtpd_epbuf.buf, xact_len, false)); } return true; } -bool tud_mtp_data_send(mtp_container_info_t* p_container) { - return mtpd_data_xfer(p_container, _mtpd_itf.ep_in); -} +bool tud_mtp_data_receive(mtp_container_info_t *p_container) { + mtpd_interface_t *p_mtp = &_mtpd_itf; + if (p_mtp->phase == MTP_PHASE_COMMAND) { + // 1st data block: header + payload + p_mtp->phase = MTP_PHASE_DATA; + p_mtp->xferred_len = 0; + p_mtp->total_len = p_container->header->len; + } -bool tud_mtp_data_receive(mtp_container_info_t* p_container) { - return mtpd_data_xfer(p_container, _mtpd_itf.ep_out); + // up to buffer size since 1st packet (with header) may also contain payload + const uint16_t xact_len = CFG_TUD_MTP_EP_BUFSIZE; + + TU_LOG_DRV(" MTP Data OUT: xferred_len/total_len=%lu/%lu, xact_len=%u\r\n", p_mtp->xferred_len, p_mtp->total_len, + xact_len); + TU_VERIFY(usbd_edpt_claim(p_mtp->rhport, p_mtp->ep_out)); + TU_ASSERT(usbd_edpt_xfer(p_mtp->rhport, p_mtp->ep_out, _mtpd_epbuf.buf, xact_len, false)); + return true; } bool tud_mtp_response_send(mtp_container_info_t* p_container) { @@ -434,22 +432,25 @@ bool mtpd_xfer_cb(uint8_t rhport, uint8_t ep_addr, xfer_result_t event, uint32_t cb_data.total_xferred_bytes = p_mtp->xferred_len; const bool is_data_in = (ep_addr == p_mtp->ep_in); - const uint16_t bulk_mps = (tud_speed_get() == TUSB_SPEED_HIGH) ? 512 : p_mtp->ep_sz_fs; // For IN endpoint, threshold is bulk max packet size // For OUT endpoint, threshold is endpoint buffer size, since we always queue fixed size - const uint16_t threshold = is_data_in ? bulk_mps : CFG_TUD_MTP_EP_BUFSIZE; + uint16_t threshold; + if (is_data_in) { + threshold = (p_mtp->ep_sz_fs > 0) ? p_mtp->ep_sz_fs : 512; // full speed bulk if set + } else { + threshold = CFG_TUD_MTP_EP_BUFSIZE; + } // Check completion: ZLP, short packet, or total length reached - bool is_complete = (xferred_bytes == 0 || - xferred_bytes < threshold || - p_mtp->xferred_len >= p_mtp->total_len); + const bool is_complete = + (xferred_bytes == 0 || xferred_bytes < threshold || p_mtp->xferred_len >= p_mtp->total_len); TU_LOG_DRV(" MTP Data %s CB: xferred_bytes=%lu, xferred_len/total_len=%lu/%lu, is_complete=%d\r\n", is_data_in ? "IN" : "OUT", xferred_bytes, p_mtp->xferred_len, p_mtp->total_len, is_complete ? 1 : 0); // Send/queue ZLP if packet is full-sized but transfer is complete if (is_complete && xferred_bytes > 0 && !(xferred_bytes & (threshold - 1))) { - TU_LOG_DRV(" QUEUE ZLP\r\n"); + TU_LOG_DRV(" queue ZLP\r\n"); TU_VERIFY(usbd_edpt_claim(p_mtp->rhport, ep_addr)); TU_ASSERT(usbd_edpt_xfer(p_mtp->rhport, ep_addr, NULL, 0, false)); return true; diff --git a/test/hil/hil_test.py b/test/hil/hil_test.py index b2e883119..dfea9612f 100755 --- a/test/hil/hil_test.py +++ b/test/hil/hil_test.py @@ -151,7 +151,7 @@ def read_disk_file(uid, lun, fname): def open_mtp_dev(uid): mtp = MTP() # MTP seems to take a while to enumerate - timeout = 2*ENUM_TIMEOUT + timeout = 2 * ENUM_TIMEOUT while timeout > 0: # run_cmd(f"gio mount -u mtp://TinyUsb_TinyUsb_Device_{uid}/") for raw in mtp.detect_devices(): @@ -617,13 +617,13 @@ def test_device_mtp(board): # device tests # note don't test 2 examples with cdc or 2 msc next to each other device_tests = [ - 'device/cdc_dual_ports', - 'device/dfu', - 'device/cdc_msc', - 'device/dfu_runtime', - 'device/cdc_msc_freertos', - 'device/hid_boot_interface', - # 'device/mtp' + # 'device/cdc_dual_ports', + # 'device/dfu', + # 'device/cdc_msc', + # 'device/dfu_runtime', + # 'device/cdc_msc_freertos', + # 'device/hid_boot_interface', + 'device/mtp' ] dual_tests = [